Upload 2 files
- dependencies.py +0 -0
- main_code.py +253 -253
dependencies.py
CHANGED
The diff for this file is too large to render. See raw diff
main_code.py
CHANGED
|
@@ -8,14 +8,14 @@
|
|
| 8 |
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 9 |
#
|
| 10 |
#
|
| 11 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
# you may not use this file except in compliance with the License.
|
| 13 |
# You may obtain a copy of the License at
|
| 14 |
#
|
| 15 |
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
#
|
| 17 |
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
# See the License for the specific language governing permissions and
|
| 21 |
# limitations under the License.
|
|
@@ -56,25 +56,25 @@ logger = logging.get_logger(__name__)
|
|
| 56 |
|
| 57 |
@dataclass
|
| 58 |
@auto_docstring(
|
| 59 |
-
custom_intro="""
|
| 60 |
Base class for Gemma3n outputs, with hidden states and attentions.
|
| 61 |
-
"""
|
| 62 |
)
|
| 63 |
class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
|
| 64 |
-
r"""
|
| 65 |
-
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 66 |
-
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 67 |
-
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 68 |
|
| 69 |
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 70 |
-
`past_key_values` input) to speed up sequential decoding.
|
| 71 |
-
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 72 |
-
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 73 |
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
| 74 |
-
audio_hidden_states (`torch.FloatTensor`, *optional*):
|
| 75 |
-
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 76 |
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
|
| 77 |
-
"""
|
| 78 |
|
| 79 |
image_hidden_states: Optional[torch.FloatTensor] = None
|
| 80 |
|
|
@@ -83,29 +83,29 @@ class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
|
|
| 83 |
|
| 84 |
@dataclass
|
| 85 |
@auto_docstring(
|
| 86 |
-
custom_intro="""
|
| 87 |
Base class for Gemma3n causal language model (or autoregressive) outputs.
|
| 88 |
-
"""
|
| 89 |
)
|
| 90 |
class Gemma3nCausalLMOutputWithPast(ModelOutput):
|
| 91 |
-
r"""
|
| 92 |
-
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 93 |
Language modeling loss (for next-token prediction).
|
| 94 |
-
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
| 95 |
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 96 |
-
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 97 |
-
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 98 |
-
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 99 |
|
| 100 |
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 101 |
-
`past_key_values` input) to speed up sequential decoding.
|
| 102 |
-
image_hidden_states (`torch.FloatTensor`, *optional*):
|
| 103 |
-
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 104 |
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
| 105 |
-
audio_hidden_states (`torch.FloatTensor`, *optional*):
|
| 106 |
-
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
| 107 |
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
|
| 108 |
-
"""
|
| 109 |
|
| 110 |
loss: Optional[torch.FloatTensor] = None
|
| 111 |
logits: Optional[torch.FloatTensor] = None
|
|
@@ -126,7 +126,7 @@ class Gemma3nRMSNorm(nn.Module):
|
|
| 126 |
if self.with_scale:
|
| 127 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 128 |
else:
|
| 129 |
-
self.register_buffer("weight", torch.tensor(1.0), persistent=False)
|
| 130 |
|
| 131 |
def _norm(self, x):
|
| 132 |
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
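Note: the `_norm` in this hunk divides by the root mean square over the last dimension. A minimal, self-contained sketch of that computation (the function name and the float32 upcast pattern here are illustrative, not the repo's exact class):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Upcast to float32 for the statistics, then cast back to the input dtype,
    # mirroring the _norm / type_as pattern shown above.
    x_f = x.float()
    normed = x_f / torch.sqrt(x_f.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * weight.float()).type_as(x)

x = torch.randn(2, 4, 8)
print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([2, 4, 8])
```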
|
@@ -138,7 +138,7 @@ class Gemma3nRMSNorm(nn.Module):
|
|
| 138 |
return output.type_as(x)
|
| 139 |
|
| 140 |
def extra_repr(self):
|
| 141 |
-
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
| 142 |
|
| 143 |
|
| 144 |
# ==== Audio Encoder ====
|
|
@@ -163,7 +163,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
|
|
| 163 |
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
|
| 164 |
inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
|
| 165 |
self.register_buffer(
|
| 166 |
-
"inv_timescales",
|
| 167 |
inv_timescales.float().unsqueeze(0).unsqueeze(0),
|
| 168 |
persistent=False,
|
| 169 |
)
|
|
@@ -184,7 +184,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
|
|
| 184 |
key_context_size: int,
|
| 185 |
max_span_plus_1: int,
|
| 186 |
) -> torch.Tensor:
|
| 187 |
-
"""Performs the relative shift.
|
| 188 |
|
| 189 |
Args:
|
| 190 |
term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
|
|
@@ -193,7 +193,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
|
|
| 193 |
|
| 194 |
Returns:
|
| 195 |
Tensor of shape [B, N, U, W, C].
|
| 196 |
-
"""
|
| 197 |
# term_bd_before_shift shape: [B, N, U, W, F_span]
|
| 198 |
# Target shape after shift: [B, N, U, W, C]
|
| 199 |
|
|
@@ -209,7 +209,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
|
|
| 209 |
term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
|
| 210 |
# Shape after pad: [B, N, U, W, C+1]
|
| 211 |
|
| 212 |
-
# Reshape for slicing (emulating JAX's behavior)
|
| 213 |
# [B, N, U, W * (C+1)]
|
| 214 |
term_bd_reshaped = term_bd_padded.reshape(
|
| 215 |
(
|
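Note: the pad/reshape in this hunk is the usual pad-flatten-reslice trick for turning query-position logits into relatively shifted ones. A toy sketch with illustrative shapes (not the model's real padding amounts):

```python
import torch
import torch.nn.functional as F

B, N, U, W, F_span = 1, 1, 1, 3, 5            # illustrative sizes; target C = F_span - 1
term_bd = torch.randn(B, N, U, W, F_span)

padded = F.pad(term_bd, (0, 1))               # [B, N, U, W, F_span + 1]
flat = padded.reshape(B, N, U, W * (F_span + 1))
sliced = flat[..., : W * F_span].reshape(B, N, U, W, F_span)
shifted = sliced[..., 1:]                     # rows are now relatively shifted
print(shifted.shape)                          # torch.Size([1, 1, 1, 3, 4])
```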
|
@@ -271,7 +271,7 @@ class Gemma3nAudioRelativePositionEmbedding(nn.Module):
|
|
| 271 |
term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
|
| 272 |
|
| 273 |
# term_bd: Query-Position interaction
|
| 274 |
-
# Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
|
| 275 |
# queries shape: [B, U, W, N, H]
|
| 276 |
# sin_emb shape: [F, N, H]
|
| 277 |
# Target output shape: [B, N, U, W, F]
|
|
@@ -338,7 +338,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 338 |
|
| 339 |
q_scale = self.head_dim**-0.5
|
| 340 |
r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
|
| 341 |
-
self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
|
| 342 |
|
| 343 |
lower_causal_mask = torch.tril(
|
| 344 |
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
|
|
@@ -350,10 +350,10 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 350 |
)
|
| 351 |
local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
|
| 352 |
local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
|
| 353 |
-
self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
|
| 354 |
|
| 355 |
self.register_buffer(
|
| 356 |
-
"softcap",
|
| 357 |
torch.tensor(self.attention_logits_soft_cap).float(),
|
| 358 |
persistent=False,
|
| 359 |
)
|
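Note: the `softcap` buffer registered here is used to soft-cap the attention logits before masking and softmax. A hedged sketch of tanh soft-capping (the cap value is illustrative):

```python
import torch

def soft_cap_logits(logits: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
    # Smoothly squash logits into (-cap, cap) rather than hard-clipping them.
    return cap * torch.tanh(logits / cap)

logits = 100.0 * torch.randn(2, 4, 16)
print(soft_cap_logits(logits).abs().max() < 50.0)  # tensor(True)
```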
|
@@ -366,7 +366,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 366 |
return x
|
| 367 |
|
| 368 |
def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 369 |
-
"""Turns a sequence to non overlapping blocks.
|
| 370 |
|
| 371 |
Args:
|
| 372 |
hidden_states: a tensor of [batch, time, ...].
|
|
@@ -375,7 +375,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 375 |
A tensor of [batch, num_blocks, block_size, ...], with necessary
|
| 376 |
paddings,
|
| 377 |
where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
|
| 378 |
-
"""
|
| 379 |
shape = hidden_states.shape
|
| 380 |
b, t = shape[:2]
|
| 381 |
num_blocks = (t + self.chunk_size - 1) // self.chunk_size
|
|
@@ -388,7 +388,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 388 |
return hidden_states
|
| 389 |
|
| 390 |
def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 391 |
-
"""Extracts temporal context for every block.
|
| 392 |
|
| 393 |
Args:
|
| 394 |
hidden_states: a tensor of [batch, time, ...].
|
|
@@ -400,11 +400,11 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 400 |
and output[:, i, ...] are x[:, start-left_context:end+right_context,
|
| 401 |
...],
|
| 402 |
start = i * block_size, end = (i + 1) * block_size.
|
| 403 |
-
"""
|
| 404 |
pad_left = self.max_past_horizon
|
| 405 |
-
# The JAX equivalent padding for signal.frame with pad_mode='valid' is
|
| 406 |
# (left_context, right_context + block_size - 1) on the time dimension.
|
| 407 |
-
# PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
|
| 408 |
# or (pad_dim_start, pad_dim_end) if two are given.
|
| 409 |
# Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
|
| 410 |
# or dim 1 (time for [B,T]).
|
|
@@ -424,7 +424,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 424 |
|
| 425 |
# If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
|
| 426 |
# If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
|
| 427 |
-
# We want to match JAX's typical output for such operations which might be
|
| 428 |
# [B, num_blocks, frame_len, N, H] if N, H are present.
|
| 429 |
# The relative_position_embedding expects keys as [B, U, C, N, H].
|
| 430 |
# If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
|
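Note: `_extract_block_context` frames the padded time axis into overlapping windows, and `Tensor.unfold` does the core work. A toy 1-D sketch with illustrative chunk/context sizes:

```python
import torch
import torch.nn.functional as F

x = torch.arange(10.0).unsqueeze(0)             # [B=1, T=10]
chunk, left, right = 4, 2, 1                     # illustrative sizes
frame_len = left + chunk + right

x_padded = F.pad(x, (left, right + chunk - 1))   # pad the time dimension
blocks = x_padded.unfold(dimension=1, size=frame_len, step=chunk)
print(blocks.shape)    # torch.Size([1, 3, 7]) -> [B, num_blocks, context]
print(blocks[0, 0])    # first block: left-context zeros, then t = 0..4
```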
|
@@ -436,7 +436,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 436 |
return x_unfolded.contiguous()
|
| 437 |
|
| 438 |
def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
|
| 439 |
-
# sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
|
| 440 |
qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
|
| 441 |
query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
|
| 442 |
key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
|
|
@@ -472,7 +472,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 472 |
extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
|
| 473 |
batch_size, num_query_blocks, self.context_size
|
| 474 |
)
|
| 475 |
-
# After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
|
| 476 |
# This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
|
| 477 |
# but for the mask case, this should hold.
|
| 478 |
if extracted_valid_mask_blocks.shape != (
|
|
@@ -481,9 +481,9 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 481 |
self.context_size,
|
| 482 |
):
|
| 483 |
raise ValueError(
|
| 484 |
-
"Shape of extracted_valid_mask_blocks"
|
| 485 |
-
f" {extracted_valid_mask_blocks.shape} is not ({batch_size}
|
| 486 |
-
f" {num_query_blocks}, {self.context_size}) after potential reshape
|
| 487 |
)
|
| 488 |
|
| 489 |
# 3. Expand dimensions for broadcasting with logits and causal mask.
|
|
@@ -518,7 +518,7 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 518 |
logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
|
| 519 |
probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
|
| 520 |
|
| 521 |
-
# context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
|
| 522 |
b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
|
| 523 |
h_dim = value_blocks.shape[-1]
|
| 524 |
prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
|
|
@@ -539,21 +539,21 @@ class Gemma3nAudioAttention(nn.Module):
|
|
| 539 |
|
| 540 |
|
| 541 |
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
|
| 542 |
-
"""Applies Group Normalization cumulatively over the time dimension.
|
| 543 |
|
| 544 |
This layer normalizes the input by calculating the mean and variance
|
| 545 |
cumulatively over the time dimension (dim 1). The statistics are computed
|
| 546 |
-
over all feature dimensions (specified by `feature_dims` and `num_channels`)
|
| 547 |
-
for elements marked as valid by the optional `mask`.
|
| 548 |
|
| 549 |
-
If a `mask` is provided (True for valid, False for invalid),
|
| 550 |
invalid time steps do not contribute to the statistics calculation, and
|
| 551 |
their corresponding output values are zeroed out.
|
| 552 |
|
| 553 |
Scale and bias, if enabled, are applied per-channel (last dimension).
|
| 554 |
-
This behavior is similar to JAX's
|
| 555 |
-
and
|
| 556 |
-
"""
|
| 557 |
|
| 558 |
def __init__(
|
| 559 |
self,
|
|
@@ -574,19 +574,19 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
|
|
| 574 |
self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
|
| 575 |
|
| 576 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 577 |
-
"""Applies cumulative group norm, optionally using a mask.
|
| 578 |
|
| 579 |
Args:
|
| 580 |
hidden_states: Input tensor, shape [B, T, *feature_dims, C].
|
| 581 |
|
| 582 |
Returns:
|
| 583 |
Normalized tensor with the same shape as x.
|
| 584 |
-
"""
|
| 585 |
expected_input_suffix = self.feature_dims + (self.num_channels,)
|
| 586 |
if hidden_states.shape[2:] != expected_input_suffix:
|
| 587 |
raise ValueError(
|
| 588 |
-
f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
|
| 589 |
-
f" suffix (feature_dims + num_channels) {expected_input_suffix}"
|
| 590 |
)
|
| 591 |
|
| 592 |
input_dtype = hidden_states.dtype
|
|
@@ -594,7 +594,7 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
|
|
| 594 |
calc_dtype = torch.float32
|
| 595 |
x_calc = hidden_states.to(calc_dtype)
|
| 596 |
|
| 597 |
-
# Prepare a broadcastable mask (`mask_calc`).
|
| 598 |
# If no mask is provided, treat all elements as valid
|
| 599 |
# (mask_calc is all ones).
|
| 600 |
# Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
|
|
@@ -607,7 +607,7 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
|
|
| 607 |
cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
|
| 608 |
|
| 609 |
# 3. Count of valid elements in the normalization group at each time step.
|
| 610 |
-
# (A "group" here consists of all features at a given Batch, Time).
|
| 611 |
elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
|
| 612 |
# 4. Cumulative count of valid elements over time.
|
| 613 |
cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
|
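Note: the statistics here are accumulated with `cumsum` over masked per-step sums and counts. A condensed sketch for a [B, T, C] input (function and shapes are illustrative; the real layer also reduces over extra feature dims):

```python
import torch

def cumulative_mean_var(x: torch.Tensor, mask: torch.Tensor):
    # x: [B, T, C]; mask: [B, T] with 1.0 for valid steps.
    m = mask[..., None].to(x.dtype)
    cum_sum = torch.cumsum((x * m).sum(-1, keepdim=True), dim=1)
    cum_cnt = torch.cumsum(m.sum(-1, keepdim=True), dim=1).clamp(min=1.0)
    cum_mean = cum_sum / cum_cnt
    sq_diff = ((x - cum_mean) ** 2 * m).sum(-1, keepdim=True)
    cum_var = torch.cumsum(sq_diff, dim=1) / cum_cnt
    return cum_mean, cum_var

mean, var = cumulative_mean_var(torch.randn(2, 5, 8), torch.ones(2, 5))
print(mean.shape, var.shape)  # torch.Size([2, 5, 1]) torch.Size([2, 5, 1])
```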
|
@@ -648,11 +648,11 @@ class Gemma3nAudioCumulativeGroupNorm(nn.Module):
|
|
| 648 |
|
| 649 |
|
| 650 |
class Gemma3nAudioSSCPConvBlock(nn.Module):
|
| 651 |
-
"""A single convolution block for the SubSampleConvProjection.
|
| 652 |
|
| 653 |
This block consists of a 2D convolution, followed by CumulativeGroupNorm,
|
| 654 |
and a ReLU activation. It handles manual padding for the convolution.
|
| 655 |
-
"""
|
| 656 |
|
| 657 |
def __init__(
|
| 658 |
self,
|
|
@@ -665,7 +665,7 @@ class Gemma3nAudioSSCPConvBlock(nn.Module):
|
|
| 665 |
self.config = config
|
| 666 |
self.manual_padding = manual_padding
|
| 667 |
|
| 668 |
-
# in_channels is 1 for the first block, or C_out from previous block's conv
|
| 669 |
in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
|
| 670 |
out_channels = self.config.sscp_conv_channel_size[idx]
|
| 671 |
kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
|
|
@@ -701,7 +701,7 @@ class Gemma3nAudioSSCPConvBlock(nn.Module):
|
|
| 701 |
# Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
|
| 702 |
# manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
|
| 703 |
# F.pad applies to last two dims: F_in then T_in
|
| 704 |
-
audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
|
| 705 |
# Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
|
| 706 |
# Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
|
| 707 |
audio_encodings_conv = self.conv(audio_encodings_padded)
|
|
@@ -728,7 +728,7 @@ class Gemma3nAudioSubSampleConvProjection(nn.Module):
|
|
| 728 |
stride_h, stride_w = config.sscp_conv_stride_size[i]
|
| 729 |
|
| 730 |
# Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
|
| 731 |
-
# JAX 'reverse_causal' padding is (0, kernel_size - 1)
|
| 732 |
pad_t_top = 0
|
| 733 |
pad_t_bottom = kernel_h - 1
|
| 734 |
|
|
@@ -736,7 +736,7 @@ class Gemma3nAudioSubSampleConvProjection(nn.Module):
|
|
| 736 |
# Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
|
| 737 |
# and the successful test configuration.
|
| 738 |
# If kernel/stride/input_freq for frequency changes, this might need re-evaluation
|
| 739 |
-
# to match generic JAX 'SAME' behavior if it differs.
|
| 740 |
pad_f_left = 1
|
| 741 |
pad_f_right = 1
|
| 742 |
|
|
@@ -792,7 +792,7 @@ class Gemma3nAudioConformerAttention(nn.Module):
|
|
| 792 |
super().__init__()
|
| 793 |
self.config = config
|
| 794 |
self.post_in_features = self.config.hidden_size
|
| 795 |
-
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 796 |
self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
| 797 |
self.attn = Gemma3nAudioAttention(config)
|
| 798 |
self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
|
|
@@ -820,7 +820,7 @@ class Gemma3nAudioConformerFeedForward(nn.Module):
|
|
| 820 |
super().__init__()
|
| 821 |
self.config = config
|
| 822 |
|
| 823 |
-
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 824 |
|
| 825 |
self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
| 826 |
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
|
|
@@ -856,7 +856,7 @@ class Gemma3nAudioConformerLightConv1d(nn.Module):
|
|
| 856 |
groups=self.config.hidden_size, # Depthwise
|
| 857 |
bias=False,
|
| 858 |
)
|
| 859 |
-
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 860 |
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
| 861 |
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
|
| 862 |
|
|
@@ -892,7 +892,7 @@ class Gemma3nAudioConformerBlock(nn.Module):
|
|
| 892 |
self.attention = Gemma3nAudioConformerAttention(self.config)
|
| 893 |
self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
|
| 894 |
self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
|
| 895 |
-
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 896 |
self.norm = Gemma3nRMSNorm(self.config.hidden_size)
|
| 897 |
|
| 898 |
def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
|
|
@@ -911,11 +911,11 @@ class Gemma3nAudioConformerBlock(nn.Module):
|
|
| 911 |
|
| 912 |
|
| 913 |
class Gemma3nAudioEncoder(PreTrainedModel):
|
| 914 |
-
"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture
|
| 915 |
|
| 916 |
config_class = Gemma3nAudioConfig
|
| 917 |
|
| 918 |
-
main_input_name = "audio_mel"
|
| 919 |
|
| 920 |
def __init__(self, config: Gemma3nAudioConfig):
|
| 921 |
super().__init__(config)
|
|
@@ -929,7 +929,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
|
|
| 929 |
def forward(
|
| 930 |
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
|
| 931 |
) -> tuple[torch.Tensor, torch.BoolTensor]:
|
| 932 |
-
"""Encodes a batch of MELs.
|
| 933 |
|
| 934 |
Args:
|
| 935 |
audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
|
|
@@ -937,10 +937,10 @@ class Gemma3nAudioEncoder(PreTrainedModel):
|
|
| 937 |
|
| 938 |
Returns:
|
| 939 |
audio_encodings: a torch.Tensor of shape
|
| 940 |
-
|
| 941 |
-
self.config.audio_config.hidden_size]
|
| 942 |
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
|
| 943 |
-
"""
|
| 944 |
audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
|
| 945 |
|
| 946 |
# Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
|
|
@@ -983,20 +983,20 @@ class Gemma3nAudioEncoder(PreTrainedModel):
|
|
| 983 |
|
| 984 |
|
| 985 |
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
|
| 986 |
-
"""
|
| 987 |
-
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
| 988 |
-
"""
|
| 989 |
|
| 990 |
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
| 991 |
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
| 992 |
-
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
| 993 |
|
| 994 |
def forward(self, input_ids: torch.Tensor):
|
| 995 |
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
| 996 |
|
| 997 |
|
| 998 |
class Gemma3nTextLaurelBlock(nn.Module):
|
| 999 |
-
"""Learned Augmented Residual Layer"""
|
| 1000 |
|
| 1001 |
def __init__(self, config: Gemma3nTextConfig):
|
| 1002 |
super().__init__()
|
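Note: the scaled word embedding in this hunk just multiplies standard embeddings by a constant stored in a non-persistent buffer. A minimal sketch (the sqrt(hidden_size) value is the usual convention and is only an example here):

```python
import torch
from torch import nn

class ScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Non-persistent: a constant, so it is not written to checkpoints.
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)

emb = ScaledEmbedding(100, 16, embed_scale=16 ** 0.5)
print(emb(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 16])
```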
|
@@ -1052,16 +1052,16 @@ class Gemma3nTextMLP(nn.Module):
|
|
| 1052 |
|
| 1053 |
|
| 1054 |
class Gemma3nTextAltUp(nn.Module):
|
| 1055 |
-
"""Alternating Updates (AltUp)
|
| 1056 |
|
| 1057 |
-
The AltUp module wraps transformer layers. The `predict` step modifies the
|
| 1058 |
-
input to the transformer layer, and the `correct` step propagates the output
|
| 1059 |
of the transformer layer to the sparsely updated dimensions.
|
| 1060 |
|
| 1061 |
See more in the research paper:
|
| 1062 |
|
| 1063 |
https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
|
| 1064 |
-
"""
|
| 1065 |
|
| 1066 |
def __init__(self, config: Gemma3nTextConfig):
|
| 1067 |
super().__init__()
|
|
@@ -1071,7 +1071,7 @@ class Gemma3nTextAltUp(nn.Module):
|
|
| 1071 |
self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
|
| 1072 |
self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
|
| 1073 |
self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
| 1074 |
-
self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
|
| 1075 |
|
| 1076 |
def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
|
| 1077 |
router_inputs = self.router_norm(x) * self.router_input_scale
|
|
@@ -1079,15 +1079,15 @@ class Gemma3nTextAltUp(nn.Module):
|
|
| 1079 |
return torch.tanh(routed.float()).type_as(x)
|
| 1080 |
|
| 1081 |
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1082 |
-
"""Predicts the output of a layer using a trainable map.
|
| 1083 |
|
| 1084 |
Args:
|
| 1085 |
-
hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
|
| 1086 |
-
stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
|
| 1087 |
|
| 1088 |
Returns:
|
| 1089 |
-
A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
|
| 1090 |
-
"""
|
| 1091 |
modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
|
| 1092 |
|
| 1093 |
if self.training and self.config.altup_coef_clip is not None:
|
|
@@ -1107,17 +1107,17 @@ class Gemma3nTextAltUp(nn.Module):
|
|
| 1107 |
return predictions.contiguous().type_as(hidden_states)
|
| 1108 |
|
| 1109 |
def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
|
| 1110 |
-
"""Corrects the predictions relative to the
|
| 1111 |
|
| 1112 |
Args:
|
| 1113 |
-
predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
|
| 1114 |
-
stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
|
| 1115 |
-
activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated hidden states.
|
| 1116 |
|
| 1117 |
Returns:
|
| 1118 |
-
A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the
|
| 1119 |
predictions relative to the activated input embeddings.
|
| 1120 |
-
"""
|
| 1121 |
modalities = self.compute_router_modalities(activated)
|
| 1122 |
innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
|
| 1123 |
innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
|
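Note: `correct` turns router modalities into one scalar coefficient per AltUp input and adds the scaled innovation back onto each prediction. A much-simplified sketch of that broadcasting pattern (all names and sizes illustrative):

```python
import torch
from torch import nn

num_inputs, batch, tokens, hidden = 4, 2, 5, 8
correction_coefs = nn.Linear(num_inputs, num_inputs, bias=False)

predictions = torch.randn(num_inputs, batch, tokens, hidden)
activated = torch.randn(batch, tokens, hidden)              # output of the wrapped layer
modalities = torch.tanh(torch.randn(batch, tokens, num_inputs))

innovation = (activated - predictions[0]).repeat(num_inputs, 1, 1, 1)
all_coefs = correction_coefs(modalities) + 1.0              # [batch, tokens, num_inputs]
all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)        # [num_inputs, batch, tokens, 1]

corrected = predictions + innovation * all_coefs
print(corrected.shape)  # torch.Size([4, 2, 5, 8])
```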
|
@@ -1125,7 +1125,7 @@ class Gemma3nTextAltUp(nn.Module):
|
|
| 1125 |
if self.config.altup_coef_clip is not None:
|
| 1126 |
self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
|
| 1127 |
|
| 1128 |
-
# all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
|
| 1129 |
# Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
|
| 1130 |
# and expand on dim1 for broadcastability
|
| 1131 |
all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
|
|
@@ -1136,26 +1136,26 @@ class Gemma3nTextAltUp(nn.Module):
|
|
| 1136 |
return corrected.contiguous().type_as(activated)
|
| 1137 |
|
| 1138 |
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
|
| 1139 |
-
"""
|
| 1140 |
-
This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
|
| 1141 |
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
|
| 1142 |
-
`scale_corrected_output`
|
| 1143 |
-
"""
|
| 1144 |
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
| 1145 |
|
| 1146 |
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
| 1147 |
-
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]
|
| 1148 |
return self.forward(corrected)
|
| 1149 |
|
| 1150 |
|
| 1151 |
class Gemma3nTextRotaryEmbedding(nn.Module):
|
| 1152 |
def __init__(self, config: Gemma3nTextConfig, device=None):
|
| 1153 |
super().__init__()
|
| 1154 |
-
# BC: "rope_type" was originally "type"
|
| 1155 |
-
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 1156 |
-
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 1157 |
else:
|
| 1158 |
-
self.rope_type = "default"
|
| 1159 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 1160 |
self.original_max_seq_len = config.max_position_embeddings
|
| 1161 |
|
|
@@ -1163,7 +1163,7 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
|
|
| 1163 |
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 1164 |
|
| 1165 |
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 1166 |
-
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 1167 |
self.original_inv_freq = self.inv_freq
|
| 1168 |
|
| 1169 |
@torch.no_grad()
|
|
@@ -1172,7 +1172,7 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
|
|
| 1172 |
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 1173 |
position_ids_expanded = position_ids[:, None, :].float()
|
| 1174 |
|
| 1175 |
-
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 1176 |
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 1177 |
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 1178 |
emb = torch.cat((freqs, freqs), dim=-1)
|
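Note: the rotary forward computes the cos/sin tables in float32 with autocast disabled. A condensed, self-contained sketch (base frequency and sizes are illustrative):

```python
import torch

dim, base = 8, 10_000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
position_ids = torch.arange(6)[None, :].float()                     # [B=1, T=6]

with torch.autocast(device_type="cpu", enabled=False):              # force float32
    freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()

print(cos.shape, sin.shape)  # torch.Size([1, 6, 8]) torch.Size([1, 6, 8])
```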
|
@@ -1183,17 +1183,17 @@ class Gemma3nTextRotaryEmbedding(nn.Module):
|
|
| 1183 |
|
| 1184 |
|
| 1185 |
def rotate_half(x):
|
| 1186 |
-
"""Rotates half the hidden dims of the input
|
| 1187 |
x1 = x[..., : x.shape[-1] // 2]
|
| 1188 |
x2 = x[..., x.shape[-1] // 2 :]
|
| 1189 |
return torch.cat((-x2, x1), dim=-1)
|
| 1190 |
|
| 1191 |
|
| 1192 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 1193 |
-
"""
|
| 1194 |
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 1195 |
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 1196 |
-
"""
|
| 1197 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 1198 |
if n_rep == 1:
|
| 1199 |
return hidden_states
|
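Note: both helpers in this hunk are small tensor utilities; this standalone sketch shows what each one does on toy inputs:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) on the last dimension.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1), used for GQA.
    b, kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(b, kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(b, kv_heads * n_rep, slen, head_dim)

print(rotate_half(torch.arange(4.0)))               # tensor([-2., -3.,  0.,  1.])
print(repeat_kv(torch.randn(2, 2, 5, 4), 3).shape)  # torch.Size([2, 6, 5, 4])
```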
|
@@ -1243,38 +1243,38 @@ def apply_rotary_pos_emb(
|
|
| 1243 |
position_ids: Optional[torch.Tensor] = None,
|
| 1244 |
unsqueeze_dim: int = 1,
|
| 1245 |
):
|
| 1246 |
-
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 1247 |
|
| 1248 |
Args:
|
| 1249 |
-
x (`torch.Tensor`): The tensor to embed.
|
| 1250 |
-
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 1251 |
-
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 1252 |
-
position_ids (`torch.Tensor`, *optional*):
|
| 1253 |
Deprecated and unused.
|
| 1254 |
-
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 1255 |
-
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 1256 |
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 1257 |
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 1258 |
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 1259 |
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 1260 |
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 1261 |
Returns:
|
| 1262 |
-
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 1263 |
-
"""
|
| 1264 |
cos = cos.unsqueeze(unsqueeze_dim)
|
| 1265 |
sin = sin.unsqueeze(unsqueeze_dim)
|
| 1266 |
return (x * cos) + (rotate_half(x) * sin)
|
| 1267 |
|
| 1268 |
|
| 1269 |
class Gemma3nTextAttention(nn.Module):
|
| 1270 |
-
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 1271 |
|
| 1272 |
def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
|
| 1273 |
super().__init__()
|
| 1274 |
-
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
| 1275 |
self.config = config
|
| 1276 |
self.layer_idx = layer_idx
|
| 1277 |
-
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 1278 |
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 1279 |
self.attention_dropout = self.config.attention_dropout
|
| 1280 |
self.is_causal = True
|
|
@@ -1356,15 +1356,15 @@ class Gemma3nTextAttention(nn.Module):
|
|
| 1356 |
if past_key_value is not None:
|
| 1357 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 1358 |
cache_kwargs = {
|
| 1359 |
-
"sin": sin,
|
| 1360 |
-
"cos": cos,
|
| 1361 |
-
"cache_position": cache_position,
|
| 1362 |
-
"sliding_window": self.sliding_window,
|
| 1363 |
}
|
| 1364 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 1365 |
|
| 1366 |
attention_interface: Callable = eager_attention_forward
|
| 1367 |
-
if self.config._attn_implementation != "eager":
|
| 1368 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 1369 |
|
| 1370 |
attn_output, attn_weights = attention_interface(
|
|
@@ -1407,7 +1407,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
|
|
| 1407 |
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
|
| 1408 |
self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 1409 |
|
| 1410 |
-
@deprecate_kwarg("last_cache_position", version
|
| 1411 |
def forward(
|
| 1412 |
self,
|
| 1413 |
hidden_states: torch.Tensor,
|
|
@@ -1460,12 +1460,12 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
|
|
| 1460 |
if self.config.altup_correct_scale:
|
| 1461 |
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
| 1462 |
|
| 1463 |
-
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
| 1464 |
first_prediction = self.per_layer_input_gate(first_prediction)
|
| 1465 |
first_prediction = self.act_fn(first_prediction)
|
| 1466 |
first_prediction = torch.multiply(first_prediction, per_layer_input)
|
| 1467 |
|
| 1468 |
-
# per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
| 1469 |
first_prediction = self.per_layer_projection(first_prediction)
|
| 1470 |
first_prediction = self.post_per_layer_input_norm(first_prediction)
|
| 1471 |
corrected_predictions[1:] += first_prediction
|
|
@@ -1481,10 +1481,10 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
|
|
| 1481 |
@auto_docstring
|
| 1482 |
class Gemma3nPreTrainedModel(PreTrainedModel):
|
| 1483 |
config_class = Gemma3nConfig
|
| 1484 |
-
base_model_prefix = ""
|
| 1485 |
supports_gradient_checkpointing = True
|
| 1486 |
-
_no_split_modules = ["Gemma3nTextDecoderLayer"]
|
| 1487 |
-
_skip_keys_device_placement = ["past_key_values"]
|
| 1488 |
_supports_flash_attn_3 = True
|
| 1489 |
_supports_flash_attn_2 = True
|
| 1490 |
_supports_sdpa = True
|
|
@@ -1495,9 +1495,9 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
|
|
| 1495 |
_supports_attention_backend = True
|
| 1496 |
|
| 1497 |
def _init_weights(self, module):
|
| 1498 |
-
# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
| 1499 |
# inference and fine-tuning - so the proper init weights code has been removed
|
| 1500 |
-
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
| 1501 |
|
| 1502 |
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
|
| 1503 |
module.weight.data.normal_(mean=0.0, std=std)
|
|
@@ -1518,7 +1518,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
|
|
| 1518 |
module.correct_output_scale.data.zero_()
|
| 1519 |
|
| 1520 |
|
| 1521 |
-
@auto_docstring(custom_intro
|
| 1522 |
class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
| 1523 |
config_class = Gemma3nTextConfig
|
| 1524 |
|
|
@@ -1544,7 +1544,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1544 |
# defaults should hold values for global RoPE.
|
| 1545 |
config = copy.deepcopy(config)
|
| 1546 |
config.rope_theta = config.rope_local_base_freq
|
| 1547 |
-
config.rope_scaling = {"rope_type": "default"}
|
| 1548 |
self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
|
| 1549 |
|
| 1550 |
self.hidden_size = config.hidden_size
|
|
@@ -1573,8 +1573,8 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1573 |
[nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
|
| 1574 |
)
|
| 1575 |
|
| 1576 |
-
self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
|
| 1577 |
-
self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
|
| 1578 |
|
| 1579 |
# Initialize weights and apply final processing
|
| 1580 |
self.post_init()
|
|
@@ -1601,10 +1601,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1601 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1602 |
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 1603 |
) -> BaseModelOutputWithPast:
|
| 1604 |
-
r"""
|
| 1605 |
per_layer_inputs (torch.Tensor, *optional*, defaults to None):
|
| 1606 |
Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
|
| 1607 |
-
"""
|
| 1608 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1609 |
output_hidden_states = (
|
| 1610 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
@@ -1612,11 +1612,11 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1612 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1613 |
|
| 1614 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1615 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1616 |
|
| 1617 |
if self.gradient_checkpointing and self.training and use_cache:
|
| 1618 |
logger.warning_once(
|
| 1619 |
-
"
|
| 1620 |
)
|
| 1621 |
use_cache = False
|
| 1622 |
|
|
@@ -1640,20 +1640,20 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1640 |
if position_ids is None:
|
| 1641 |
position_ids = cache_position.unsqueeze(0)
|
| 1642 |
|
| 1643 |
-
# It may already have been prepared by e.g. `generate`
|
| 1644 |
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 1645 |
# Prepare mask arguments
|
| 1646 |
mask_kwargs = {
|
| 1647 |
-
"config": self.config,
|
| 1648 |
-
"input_embeds": inputs_embeds,
|
| 1649 |
-
"attention_mask": attention_mask,
|
| 1650 |
-
"cache_position": cache_position,
|
| 1651 |
-
"past_key_values": past_key_values,
|
| 1652 |
}
|
| 1653 |
# Create the masks
|
| 1654 |
causal_mask_mapping = {
|
| 1655 |
-
"full_attention": create_causal_mask(**mask_kwargs),
|
| 1656 |
-
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
| 1657 |
}
|
| 1658 |
|
| 1659 |
# embed positions
|
|
@@ -1669,7 +1669,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1669 |
|
| 1670 |
temp_hidden_states = [hidden_states_0]
|
| 1671 |
for i in range(1, self.config.altup_num_inputs):
|
| 1672 |
-
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
| 1673 |
altup_proj = self.altup_projections[i - 1](hidden_states_0)
|
| 1674 |
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
| 1675 |
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
|
@@ -1717,7 +1717,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1717 |
target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
|
| 1718 |
temp_hidden_states = [hidden_states[0]]
|
| 1719 |
for i in range(1, self.config.altup_num_inputs):
|
| 1720 |
-
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
| 1721 |
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
|
| 1722 |
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
| 1723 |
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
|
@@ -1771,14 +1771,14 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
|
| 1771 |
)
|
| 1772 |
|
| 1773 |
|
| 1774 |
-
@auto_docstring(custom_intro
|
| 1775 |
class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
| 1776 |
-
_tied_weights_keys = ["lm_head.weight"]
|
| 1777 |
-
_tp_plan = {"lm_head": "colwise_rep"}
|
| 1778 |
-
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 1779 |
config_class = Gemma3nTextConfig
|
| 1780 |
-
base_model_prefix = "model"
|
| 1781 |
-
_checkpoint_conversion_mapping = {"model.language_model": "model"}
|
| 1782 |
|
| 1783 |
def __init__(self, config: Gemma3nTextConfig):
|
| 1784 |
super().__init__(config)
|
|
@@ -1824,33 +1824,33 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 1824 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1825 |
**loss_kwargs,
|
| 1826 |
) -> CausalLMOutputWithPast:
|
| 1827 |
-
r"""
|
| 1828 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1829 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1830 |
-
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1831 |
-
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1832 |
|
| 1833 |
Example:
|
| 1834 |
|
| 1835 |
-
```python
|
| 1836 |
>>> from transformers import AutoTokenizer, Gemma3nForCausalLM
|
| 1837 |
|
| 1838 |
-
>>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
|
| 1839 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
| 1840 |
|
| 1841 |
-
>>> prompt = "What is your favorite condiment
|
| 1842 |
-
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1843 |
|
| 1844 |
>>> # Generate
|
| 1845 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1846 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1847 |
-
"What is your favorite condiment
|
| 1848 |
-
```"""
|
| 1849 |
|
| 1850 |
-
if self.training and self.config._attn_implementation != "eager":
|
| 1851 |
logger.warning_once(
|
| 1852 |
-
"It is strongly recommended to train Gemma3n models with the
|
| 1853 |
-
f"instead of
|
| 1854 |
)
|
| 1855 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1856 |
output_hidden_states = (
|
|
@@ -1893,7 +1893,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 1893 |
|
| 1894 |
|
| 1895 |
class Gemma3nMultimodalEmbedder(nn.Module):
|
| 1896 |
-
"""Embeds token ids or soft tokens for multimodal content into language model space
|
| 1897 |
|
| 1898 |
def __init__(
|
| 1899 |
self,
|
|
@@ -1919,18 +1919,18 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
|
| 1919 |
input_ids: Optional[torch.LongTensor] = None,
|
| 1920 |
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1921 |
) -> torch.Tensor:
|
| 1922 |
-
"""Embeds token ids or soft tokens for multimodal content into language model space.
|
| 1923 |
|
| 1924 |
Args:
|
| 1925 |
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
|
| 1926 |
-
`[vocab_offset, vocab_offset + vocab_size)`.
|
| 1927 |
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
|
| 1928 |
|
| 1929 |
Returns:
|
| 1930 |
-
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
|
| 1931 |
-
"""
|
| 1932 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1933 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1934 |
|
| 1935 |
if inputs_embeds is not None:
|
| 1936 |
emb_norm = self.soft_embedding_norm(inputs_embeds)
|
|
@@ -1943,14 +1943,14 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
|
| 1943 |
|
| 1944 |
|
| 1945 |
@auto_docstring(
|
| 1946 |
-
custom_intro="""
|
| 1947 |
The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
|
| 1948 |
language modeling head.
|
| 1949 |
-
"""
|
| 1950 |
)
|
| 1951 |
class Gemma3nModel(Gemma3nPreTrainedModel):
|
| 1952 |
_checkpoint_conversion_mapping = {}
|
| 1953 |
-
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
| 1954 |
accepts_loss_kwargs = False
|
| 1955 |
|
| 1956 |
def __init__(self, config: Gemma3nConfig):
|
|
@@ -1981,16 +1981,16 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
|
|
| 1981 |
return self.language_model
|
| 1982 |
|
| 1983 |
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 1984 |
-
"""
|
| 1985 |
Projects the last hidden state from the vision model into language model space.
|
| 1986 |
|
| 1987 |
Args:
|
| 1988 |
-
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
| 1989 |
The tensors corresponding to the input images.
|
| 1990 |
|
| 1991 |
Returns:
|
| 1992 |
-
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
| 1993 |
-
"""
|
| 1994 |
vision_outputs = self.vision_tower(
|
| 1995 |
pixel_values=pixel_values, do_pooling=False, return_dict=True
|
| 1996 |
).last_hidden_state
|
|
@@ -2024,36 +2024,36 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
|
|
| 2024 |
output_hidden_states: Optional[bool] = None,
|
| 2025 |
**lm_kwargs,
|
| 2026 |
) -> Gemma3nCausalLMOutputWithPast:
|
| 2027 |
-
r"""
|
| 2028 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 2029 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 2030 |
-
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 2031 |
-
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
| 2032 |
|
| 2033 |
Example:
|
| 2034 |
|
| 2035 |
-
```python
|
| 2036 |
>>> from PIL import Image
|
| 2037 |
>>> import requests
|
| 2038 |
>>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
|
| 2039 |
|
| 2040 |
-
>>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
|
| 2041 |
-
>>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
|
| 2042 |
|
| 2043 |
-
>>> prompt = "Where is the cat standing
|
| 2044 |
-
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
| 2045 |
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 2046 |
|
| 2047 |
-
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
| 2048 |
|
| 2049 |
>>> # Generate
|
| 2050 |
>>> generate_ids = model.generate(**inputs,)
|
| 2051 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 2052 |
-
"Where is the cat standing?\nsnow"
|
| 2053 |
-
```
|
| 2054 |
-
"""
|
| 2055 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 2056 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 2057 |
|
| 2058 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2059 |
output_hidden_states = (
|
|
@@ -2103,9 +2103,9 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
|
|
| 2103 |
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 2104 |
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
| 2105 |
raise ValueError(
|
| 2106 |
-
f"Number of images does not match number of special image tokens in the input text. "
|
| 2107 |
-
f"Got {image_tokens_in_text} image tokens in the text and "
|
| 2108 |
-
f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings
|
| 2109 |
)
|
| 2110 |
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 2111 |
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
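Note: `masked_scatter` fills the masked token positions, in order, from the flattened image features. A toy sketch of that merge:

```python
import torch

hidden = 4
inputs_embeds = torch.zeros(1, 6, hidden)
special_image_mask = torch.tensor([[0, 1, 1, 0, 1, 0]], dtype=torch.bool)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)

image_features = torch.ones(3, hidden)  # one row per image placeholder token
merged = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(merged[0, :, 0])  # tensor([0., 1., 1., 0., 1., 0.])
```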
|
@@ -2140,9 +2140,9 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
|
|
| 2140 |
if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
|
| 2141 |
audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
|
| 2142 |
raise ValueError(
|
| 2143 |
-
f"Number of audio input features does not match number of special audio tokens in the input text. "
|
| 2144 |
-
f"Got {audio_tokens_in_text} audio tokens in the text and "
|
| 2145 |
-
f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings
|
| 2146 |
)
|
| 2147 |
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 2148 |
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
|
|
@@ -2174,32 +2174,32 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
|
|
| 2174 |
def get_audio_features(
|
| 2175 |
self, input_features: torch.Tensor, input_features_mask: torch.Tensor
|
| 2176 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 2177 |
-
"""
|
| 2178 |
Projects the last hidden state from the audio encoder into language model space.
|
| 2179 |
|
| 2180 |
Args:
|
| 2181 |
-
input_features (
|
| 2182 |
The tensors corresponding to the input audio.
|
| 2183 |
-
input_features (
|
| 2184 |
The attention mask for the input audio.
|
| 2185 |
|
| 2186 |
Returns:
|
| 2187 |
-
audio_features (
|
| 2188 |
-
"""
|
| 2189 |
audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
|
| 2190 |
return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
|
| 2191 |
|
| 2192 |
|
| 2193 |
@auto_docstring(
|
| 2194 |
-
custom_intro="""
|
| 2195 |
The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
| 2196 |
head.
|
| 2197 |
-
"""
|
| 2198 |
)
|
| 2199 |
class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
| 2200 |
_checkpoint_conversion_mapping = {}
|
| 2201 |
-
_tied_weights_keys = ["lm_head.weight"]
|
| 2202 |
-
base_model_prefix = "model"
|
| 2203 |
|
| 2204 |
def __init__(self, config: Gemma3nConfig):
|
| 2205 |
super().__init__(config)
|
|
@@ -2239,7 +2239,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 2239 |
|
| 2240 |
@property
|
| 2241 |
def multi_modal_projector(self):
|
| 2242 |
-
raise AttributeError("Use embed_vision instead of multi_modal_projector
|
| 2243 |
|
| 2244 |
@can_return_tuple
|
| 2245 |
@auto_docstring
|
|
@@ -2262,38 +2262,38 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 2262 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 2263 |
**lm_kwargs,
|
| 2264 |
) -> Gemma3nCausalLMOutputWithPast:
|
| 2265 |
-
r"""
|
| 2266 |
input_features (torch.Tensor, *optional*, defaults to None):
|
| 2267 |
The audio inputs to be encoded.
|
| 2268 |
input_features_mask (torch.Tensor, *optional*, defaults to None):
|
| 2269 |
The attention mask for the input audio.
|
| 2270 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 2271 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 2272 |
-
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
|
| 2273 |
ignored (masked), the loss is only computed for the tokens with labels in
|
| 2274 |
-
`[0, ..., config.text_config.vocab_size]`.
|
| 2275 |
|
| 2276 |
Example:
|
| 2277 |
|
| 2278 |
-
```python
|
| 2279 |
>>> from PIL import Image
|
| 2280 |
>>> import requests
|
| 2281 |
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 2282 |
|
| 2283 |
-
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
| 2284 |
-
>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
| 2285 |
|
| 2286 |
>>> messages = [
|
| 2287 |
... {
|
| 2288 |
-
... "role": "system",
|
| 2289 |
-
... "content": [
|
| 2290 |
-
... {"type": "text", "text": "You are a helpful assistant
|
| 2291 |
... ]
|
| 2292 |
... },
|
| 2293 |
... {
|
| 2294 |
-
... "role": "user", "content": [
|
| 2295 |
-
... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
| 2296 |
-
... {"type": "text", "text": "Where is the cat standing
|
| 2297 |
... ]
|
| 2298 |
... },
|
| 2299 |
... ]
|
|
@@ -2302,15 +2302,15 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 2302 |
... messages,
|
| 2303 |
... tokenizer=True,
|
| 2304 |
... return_dict=True,
|
| 2305 |
-
... return_tensors="pt",
|
| 2306 |
... add_generation_prompt=True
|
| 2307 |
... )
|
| 2308 |
>>> # Generate
|
| 2309 |
>>> generate_ids = model.generate(**inputs)
|
| 2310 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 2311 |
-
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
| 2312 |
-
```
|
| 2313 |
-
"""
|
| 2314 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2315 |
output_hidden_states = (
|
| 2316 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
@@ -2393,7 +2393,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 2393 |
labels=None,
|
| 2394 |
**kwargs,
|
| 2395 |
):
|
| 2396 |
-
# Overwritten -- custom
|
| 2397 |
model_inputs = super().prepare_inputs_for_generation(
|
| 2398 |
input_ids,
|
| 2399 |
past_key_values=past_key_values,
|
|
@@ -2407,13 +2407,13 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 2407 |
**kwargs,
|
| 2408 |
)
|
| 2409 |
|
| 2410 |
-
# If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
|
| 2411 |
# tokens anymore. Otherwise multimodal inputs should be passed to model.
|
| 2412 |
# NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
|
| 2413 |
if cache_position[0] == 0:
|
| 2414 |
-
model_inputs["pixel_values"] = pixel_values
|
| 2415 |
-
model_inputs["input_features"] = input_features
|
| 2416 |
-
model_inputs["input_features_mask"] = input_features_mask
|
| 2417 |
|
| 2418 |
return model_inputs
|
| 2419 |
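Note: `prepare_inputs_for_generation` above only attaches the multimodal tensors on the prefill step, since cached decoding steps contain no placeholder tokens. A tiny sketch of the same check (hypothetical helper, not the repo's API):

```python
import torch

def gate_multimodal_inputs(model_inputs: dict, cache_position: torch.Tensor, **multimodal) -> dict:
    # Only the first (prefill) forward pass sees the image/audio tensors.
    if cache_position[0] == 0:
        model_inputs.update(multimodal)
    return model_inputs

print(gate_multimodal_inputs({}, torch.tensor([0]), pixel_values="px"))  # {'pixel_values': 'px'}
print(gate_multimodal_inputs({}, torch.tensor([7]), pixel_values="px"))  # {}
```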
|
|
@@ -2423,10 +2423,10 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
|
| 2423 |
|
| 2424 |
|
| 2425 |
__all__ = [
|
| 2426 |
-
"Gemma3nAudioEncoder",
|
| 2427 |
-
"Gemma3nForCausalLM",
|
| 2428 |
-
"Gemma3nForConditionalGeneration",
|
| 2429 |
-
"Gemma3nModel",
|
| 2430 |
-
"Gemma3nPreTrainedModel",
|
| 2431 |
-
"Gemma3nTextModel",
|
| 2432 |
]
|
|
|
|
| 8 |
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
| 9 |
#
|
| 10 |
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the \"License\");
|
| 12 |
# you may not use this file except in compliance with the License.
|
| 13 |
# You may obtain a copy of the License at
|
| 14 |
#
|
| 15 |
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
#
|
| 17 |
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an \"AS IS\" BASIS,
|
| 19 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
# See the License for the specific language governing permissions and
|
| 21 |
# limitations under the License.
|
|
|
|
| 56 |
|
| 57 |
@dataclass
|
| 58 |
@auto_docstring(
|
| 59 |
+
custom_intro=\"\"\"
|
| 60 |
Base class for Gemma3n outputs, with hidden states and attentions.
|
| 61 |
+
\"\"\"
|
| 62 |
)
|
| 63 |
class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
|
| 64 |
+
r\"\"\"
|
| 65 |
+
past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
|
| 66 |
+
Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
|
| 67 |
+
\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
|
| 68 |
|
| 69 |
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 70 |
+
\`past_key_values\` input) to speed up sequential decoding.
|
| 71 |
+
image_hidden_states (\`torch.FloatTensor\`, *optional*):
|
| 72 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
| 73 |
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
| 74 |
+
audio_hidden_states (\`torch.FloatTensor\`, *optional*):
|
| 75 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
| 76 |
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
|
| 77 |
+
\"\"\"
|
| 78 |
|
| 79 |
image_hidden_states: Optional[torch.FloatTensor] = None
|
| 80 |
|
|
|
|
| 83 |
|
| 84 |
@dataclass
|
| 85 |
@auto_docstring(
|
| 86 |
+
custom_intro=\"\"\"
|
| 87 |
Base class for Gemma3n causal language model (or autoregressive) outputs.
|
| 88 |
+
\"\"\"
|
| 89 |
)
|
| 90 |
class Gemma3nCausalLMOutputWithPast(ModelOutput):
|
| 91 |
+
r\"\"\"
|
| 92 |
+
loss (\`torch.FloatTensor\` of shape \`(1,)\`, *optional*, returned when \`labels\` is provided):
|
| 93 |
Language modeling loss (for next-token prediction).
|
| 94 |
+
logits (\`torch.FloatTensor\` of shape \`(batch_size, sequence_length, config.text_config.vocab_size)\`):
|
| 95 |
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 96 |
+
past_key_values (\`tuple(tuple(torch.FloatTensor))\`, *optional*, returned when \`use_cache=True\` is passed or when \`config.use_cache=True\`):
|
| 97 |
+
Tuple of \`tuple(torch.FloatTensor)\` of length \`config.n_layers\`, with each tuple having 2 tensors of shape
|
| 98 |
+
\`(batch_size, num_heads, sequence_length, embed_size_per_head)\`)
|
| 99 |
|
| 100 |
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 101 |
+
\`past_key_values\` input) to speed up sequential decoding.
|
| 102 |
+
image_hidden_states (\`torch.FloatTensor\`, *optional*):
|
| 103 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
| 104 |
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
| 105 |
+
audio_hidden_states (\`torch.FloatTensor\`, *optional*):
|
| 106 |
+
A \`torch.FloatTensor\` of size \`(batch_size, num_images, sequence_length, hidden_size)\`.
|
| 107 |
audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
|
| 108 |
+
\"\"\"
|
| 109 |
|
| 110 |
loss: Optional[torch.FloatTensor] = None
|
| 111 |
logits: Optional[torch.FloatTensor] = None
|
|
|
|
| 126 |
if self.with_scale:
|
| 127 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 128 |
else:
|
| 129 |
+
self.register_buffer(\"weight\", torch.tensor(1.0), persistent=False)
|
| 130 |
|
| 131 |
def _norm(self, x):
|
| 132 |
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
|
|
| 138 |
return output.type_as(x)
|
| 139 |
|
| 140 |
def extra_repr(self):
|
| 141 |
+
return f\"{tuple(self.weight.shape)}, eps={self.eps}\"
|
| 142 |
|
| 143 |
|
| 144 |
# ==== Audio Encoder ====
|
|
|
|
| 163 |
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
|
| 164 |
inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
|
| 165 |
self.register_buffer(
|
| 166 |
+
\"inv_timescales\",
|
| 167 |
inv_timescales.float().unsqueeze(0).unsqueeze(0),
|
| 168 |
persistent=False,
|
| 169 |
)
|
|
|
|
| 184 |
key_context_size: int,
|
| 185 |
max_span_plus_1: int,
|
| 186 |
) -> torch.Tensor:
|
| 187 |
+
\"\"\"Performs the relative shift.
|
| 188 |
|
| 189 |
Args:
|
| 190 |
term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
|
|
|
|
| 193 |
|
| 194 |
Returns:
|
| 195 |
Tensor of shape [B, N, U, W, C].
|
| 196 |
+
\"\"\"
|
| 197 |
# term_bd_before_shift shape: [B, N, U, W, F_span]
|
| 198 |
# Target shape after shift: [B, N, U, W, C]
|
| 199 |
|
|
|
|
| 209 |
term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
|
| 210 |
# Shape after pad: [B, N, U, W, C+1]
|
| 211 |
|
| 212 |
+
# Reshape for slicing (emulating JAX\'s behavior)
|
| 213 |
# [B, N, U, W * (C+1)]
|
| 214 |
term_bd_reshaped = term_bd_padded.reshape(
|
| 215 |
(
|
|
|
|
| 271 |
term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
|
| 272 |
|
| 273 |
# term_bd: Query-Position interaction
|
| 274 |
+
# Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
|
| 275 |
# queries shape: [B, U, W, N, H]
|
| 276 |
# sin_emb shape: [F, N, H]
|
| 277 |
# Target output shape: [B, N, U, W, F]
|
|
|
|
| 338 |
|
| 339 |
q_scale = self.head_dim**-0.5
|
| 340 |
r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
|
| 341 |
+
self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
|
| 342 |
|
| 343 |
lower_causal_mask = torch.tril(
|
| 344 |
torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
|
|
|
|
| 350 |
)
|
| 351 |
local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
|
| 352 |
local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
|
| 353 |
+
self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
|
| 354 |
|
| 355 |
self.register_buffer(
|
| 356 |
+
"softcap",
|
| 357 |
torch.tensor(self.attention_logits_soft_cap).float(),
|
| 358 |
persistent=False,
|
| 359 |
)
|
|
|
|
| 366 |
return x
|
| 367 |
|
| 368 |
def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 369 |
+
"""Turns a sequence into non-overlapping blocks.
|
| 370 |
|
| 371 |
Args:
|
| 372 |
hidden_states: a tensor of [batch, time, ...].
|
|
|
|
| 375 |
A tensor of [batch, num_blocks, block_size, ...], with necessary
|
| 376 |
paddings,
|
| 377 |
where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
|
| 378 |
+
"""
|
| 379 |
shape = hidden_states.shape
|
| 380 |
b, t = shape[:2]
|
| 381 |
num_blocks = (t + self.chunk_size - 1) // self.chunk_size
|
|
|
|
| 388 |
return hidden_states
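# Illustrative sketch (not part of the original file) of the blocking idea implemented by
# `_convert_to_block` above: pad the time axis up to a multiple of chunk_size, then reshape
# so that block i holds x[:, i*chunk_size:(i+1)*chunk_size].
import torch

chunk_size = 4
x = torch.arange(10).float().unsqueeze(0)                 # [batch=1, time=10]
num_blocks = (x.shape[1] + chunk_size - 1) // chunk_size  # 3 blocks
pad = num_blocks * chunk_size - x.shape[1]                # 2 padding frames
blocks = torch.nn.functional.pad(x, (0, pad)).reshape(1, num_blocks, chunk_size)
print(blocks.shape)                                       # torch.Size([1, 3, 4])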
|
| 389 |
|
| 390 |
def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 391 |
+
"""Extracts temporal context for every block.
|
| 392 |
|
| 393 |
Args:
|
| 394 |
hidden_states: a tensor of [batch, time, ...].
|
|
|
|
| 400 |
and output[:, i, ...] are x[:, start-left_context:end+right_context,
|
| 401 |
...],
|
| 402 |
start = i * block_size, end = (i + 1) * block_size.
|
| 403 |
+
"""
|
| 404 |
pad_left = self.max_past_horizon
|
| 405 |
+
# The JAX equivalent padding for signal.frame with pad_mode='valid' is
|
| 406 |
# (left_context, right_context + block_size - 1) on the time dimension.
|
| 407 |
+
# PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
|
| 408 |
# or (pad_dim_start, pad_dim_end) if two are given.
|
| 409 |
# Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
|
| 410 |
# or dim 1 (time for [B,T]).
|
|
|
|
| 424 |
|
| 425 |
# If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
|
| 426 |
# If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
|
| 427 |
+
# We want to match JAX's typical output for such operations which might be
|
| 428 |
# [B, num_blocks, frame_len, N, H] if N, H are present.
|
| 429 |
# The relative_position_embedding expects keys as [B, U, C, N, H].
|
| 430 |
# If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
|
|
|
|
| 436 |
return x_unfolded.contiguous()
|
| 437 |
|
| 438 |
def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
|
| 439 |
+
# sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
|
| 440 |
qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
|
| 441 |
query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
|
| 442 |
key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
|
|
|
|
| 472 |
extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
|
| 473 |
batch_size, num_query_blocks, self.context_size
|
| 474 |
)
|
| 475 |
+
# After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
|
| 476 |
# This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
|
| 477 |
# but for the mask case, this should hold.
|
| 478 |
if extracted_valid_mask_blocks.shape != (
|
|
|
|
| 481 |
self.context_size,
|
| 482 |
):
|
| 483 |
raise ValueError(
|
| 484 |
+
"Shape of extracted_valid_mask_blocks"
|
| 485 |
+
f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
|
| 486 |
+
f" {num_query_blocks}, {self.context_size}) after potential reshape."
|
| 487 |
)
|
| 488 |
|
| 489 |
# 3. Expand dimensions for broadcasting with logits and causal mask.
|
|
|
|
| 518 |
logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
|
| 519 |
probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
|
| 520 |
|
| 521 |
+
# context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
|
| 522 |
b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
|
| 523 |
h_dim = value_blocks.shape[-1]
|
| 524 |
prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
|
|
|
|
| 539 |
|
| 540 |
|
| 541 |
class Gemma3nAudioCumulativeGroupNorm(nn.Module):
|
| 542 |
+
"""Applies Group Normalization cumulatively over the time dimension.
|
| 543 |
|
| 544 |
This layer normalizes the input by calculating the mean and variance
|
| 545 |
cumulatively over the time dimension (dim 1). The statistics are computed
|
| 546 |
+
over all feature dimensions (specified by `feature_dims` and `num_channels`)
|
| 547 |
+
for elements marked as valid by the optional `mask`.
|
| 548 |
|
| 549 |
+
If a `mask` is provided (True for valid, False for invalid/padded),
|
| 550 |
invalid time steps do not contribute to the statistics calculation, and
|
| 551 |
their corresponding output values are zeroed out.
|
| 552 |
|
| 553 |
Scale and bias, if enabled, are applied per-channel (last dimension).
|
| 554 |
+
This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
|
| 555 |
+
and `cumulative=True`.
|
| 556 |
+
"""
|
| 557 |
|
| 558 |
def __init__(
|
| 559 |
self,
|
|
|
|
| 574 |
self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
|
| 575 |
|
| 576 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 577 |
+
"""Applies cumulative group norm, optionally using a mask.
|
| 578 |
|
| 579 |
Args:
|
| 580 |
hidden_states: Input tensor, shape [B, T, *feature_dims, C].
|
| 581 |
|
| 582 |
Returns:
|
| 583 |
Normalized tensor with the same shape as x.
|
| 584 |
+
"""
|
| 585 |
expected_input_suffix = self.feature_dims + (self.num_channels,)
|
| 586 |
if hidden_states.shape[2:] != expected_input_suffix:
|
| 587 |
raise ValueError(
|
| 588 |
+
f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
|
| 589 |
+
f" suffix (feature_dims + num_channels) {expected_input_suffix}"
|
| 590 |
)
|
| 591 |
|
| 592 |
input_dtype = hidden_states.dtype
|
|
|
|
| 594 |
calc_dtype = torch.float32
|
| 595 |
x_calc = hidden_states.to(calc_dtype)
|
| 596 |
|
| 597 |
+
# Prepare a broadcastable mask (`mask_calc`).
|
| 598 |
# If no mask is provided, treat all elements as valid
|
| 599 |
# (mask_calc is all ones).
|
| 600 |
# Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
|
|
|
|
| 607 |
cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
|
| 608 |
|
| 609 |
# 3. Count of valid elements in the normalization group at each time step.
|
| 610 |
+
# (A "group" here consists of all features at a given Batch, Time).
|
| 611 |
elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
|
| 612 |
# 4. Cumulative count of valid elements over time.
|
| 613 |
cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
|
|
|
|
| 648 |
|
| 649 |
|
| 650 |
class Gemma3nAudioSSCPConvBlock(nn.Module):
|
| 651 |
+
"""A single convolution block for the SubSampleConvProjection.
|
| 652 |
|
| 653 |
This block consists of a 2D convolution, followed by CumulativeGroupNorm,
|
| 654 |
and a ReLU activation. It handles manual padding for the convolution.
|
| 655 |
+
"""
|
| 656 |
|
| 657 |
def __init__(
|
| 658 |
self,
|
|
|
|
| 665 |
self.config = config
|
| 666 |
self.manual_padding = manual_padding
|
| 667 |
|
| 668 |
+
# in_channels is 1 for the first block, or C_out from previous block's conv
|
| 669 |
in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
|
| 670 |
out_channels = self.config.sscp_conv_channel_size[idx]
|
| 671 |
kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
|
|
|
|
| 701 |
# Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
|
| 702 |
# manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
|
| 703 |
# F.pad applies to last two dims: F_in then T_in
|
| 704 |
+
audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
|
| 705 |
# Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
|
| 706 |
# Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
|
| 707 |
audio_encodings_conv = self.conv(audio_encodings_padded)
|
|
|
|
| 728 |
stride_h, stride_w = config.sscp_conv_stride_size[i]
|
| 729 |
|
| 730 |
# Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
|
| 731 |
+
# JAX 'reverse_causal' padding is (0, kernel_size - 1)
|
| 732 |
pad_t_top = 0
|
| 733 |
pad_t_bottom = kernel_h - 1
|
| 734 |
|
|
|
|
| 736 |
# Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
|
| 737 |
# and the successful test configuration.
|
| 738 |
# If kernel/stride/input_freq for frequency changes, this might need re-evaluation
|
| 739 |
+
# to match generic JAX 'SAME' behavior if it differs.
|
| 740 |
pad_f_left = 1
|
| 741 |
pad_f_right = 1
|
| 742 |
|
|
|
|
| 792 |
super().__init__()
|
| 793 |
self.config = config
|
| 794 |
self.post_in_features = self.config.hidden_size
|
| 795 |
+
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 796 |
self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
| 797 |
self.attn = Gemma3nAudioAttention(config)
|
| 798 |
self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
|
|
|
|
| 820 |
super().__init__()
|
| 821 |
self.config = config
|
| 822 |
|
| 823 |
+
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 824 |
|
| 825 |
self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
|
| 826 |
self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
|
|
|
|
| 856 |
groups=self.config.hidden_size, # Depthwise
|
| 857 |
bias=False,
|
| 858 |
)
|
| 859 |
+
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 860 |
self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
| 861 |
self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
|
| 862 |
|
|
|
|
| 892 |
self.attention = Gemma3nAudioConformerAttention(self.config)
|
| 893 |
self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
|
| 894 |
self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
|
| 895 |
+
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
|
| 896 |
self.norm = Gemma3nRMSNorm(self.config.hidden_size)
|
| 897 |
|
| 898 |
def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
|
|
|
|
| 911 |
|
| 912 |
|
| 913 |
class Gemma3nAudioEncoder(PreTrainedModel):
|
| 914 |
+
"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
|
| 915 |
|
| 916 |
config_class = Gemma3nAudioConfig
|
| 917 |
|
| 918 |
+
main_input_name = "audio_mel"
|
| 919 |
|
| 920 |
def __init__(self, config: Gemma3nAudioConfig):
|
| 921 |
super().__init__(config)
|
|
|
|
| 929 |
def forward(
|
| 930 |
self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
|
| 931 |
) -> tuple[torch.Tensor, torch.BoolTensor]:
|
| 932 |
+
"""Encodes a batch of MELs.
|
| 933 |
|
| 934 |
Args:
|
| 935 |
audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
|
|
|
|
| 937 |
|
| 938 |
Returns:
|
| 939 |
audio_encodings: a torch.Tensor of shape
|
| 940 |
+
`[batch_size, self.config.audio_soft_tokens_per_image,
|
| 941 |
+
self.config.audio_config.hidden_size]`
|
| 942 |
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
|
| 943 |
+
"""
|
| 944 |
audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
|
| 945 |
|
| 946 |
# Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
|
|
|
|
| 983 |
|
| 984 |
|
| 985 |
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
|
| 986 |
+
"""
|
| 987 |
+
This module overrides nn.Embedding's forward by multiplying the output with the embedding scale.
|
| 988 |
+
"""
|
| 989 |
|
| 990 |
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
| 991 |
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
| 992 |
+
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
| 993 |
|
| 994 |
def forward(self, input_ids: torch.Tensor):
|
| 995 |
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
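# Illustrative sketch (not part of the original file) of the scaled-embedding pattern above:
# a plain nn.Embedding whose output is multiplied by a fixed scale. The hidden_size**0.5
# value used here is only an assumed example; the real scale is supplied by the caller.
import torch
import torch.nn as nn

hidden_size = 8
embedding = nn.Embedding(16, hidden_size, padding_idx=0)
embed_scale = hidden_size**0.5
tokens = torch.tensor([[1, 2, 3]])
print((embedding(tokens) * embed_scale).shape)            # torch.Size([1, 3, 8])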
|
| 996 |
|
| 997 |
|
| 998 |
class Gemma3nTextLaurelBlock(nn.Module):
|
| 999 |
+
"""Learned Augmented Residual Layer"""
|
| 1000 |
|
| 1001 |
def __init__(self, config: Gemma3nTextConfig):
|
| 1002 |
super().__init__()
|
|
|
|
| 1052 |
|
| 1053 |
|
| 1054 |
class Gemma3nTextAltUp(nn.Module):
|
| 1055 |
+
"""Alternating Updates (AltUp)
|
| 1056 |
|
| 1057 |
+
The AltUp module wraps transformer layers. The `predict` step modifies the
|
| 1058 |
+
input to the transformer layer, and the `correct` step propagates the output
|
| 1059 |
of the transformer layer to the sparsely updated dimensions.
|
| 1060 |
|
| 1061 |
See more in the research paper:
|
| 1062 |
|
| 1063 |
https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
|
| 1064 |
+
"""
|
| 1065 |
|
| 1066 |
def __init__(self, config: Gemma3nTextConfig):
|
| 1067 |
super().__init__()
|
|
|
|
| 1071 |
self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
|
| 1072 |
self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
|
| 1073 |
self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
| 1074 |
+
self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
|
| 1075 |
|
| 1076 |
def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
|
| 1077 |
router_inputs = self.router_norm(x) * self.router_input_scale
|
|
|
|
| 1079 |
return torch.tanh(routed.float()).type_as(x)
|
| 1080 |
|
| 1081 |
def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1082 |
+
"""Predicts the output of a layer using a trainable map.
|
| 1083 |
|
| 1084 |
Args:
|
| 1085 |
+
hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
|
| 1086 |
+
stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
|
| 1087 |
|
| 1088 |
Returns:
|
| 1089 |
+
A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
|
| 1090 |
+
"""
|
| 1091 |
modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
|
| 1092 |
|
| 1093 |
if self.training and self.config.altup_coef_clip is not None:
|
|
|
|
| 1107 |
return predictions.contiguous().type_as(hidden_states)
|
| 1108 |
|
| 1109 |
def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
|
| 1110 |
+
"""Corrects the predictions relative to the activated input embeddings.
|
| 1111 |
|
| 1112 |
Args:
|
| 1113 |
+
predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
|
| 1114 |
+
stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
|
| 1115 |
+
activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
|
| 1116 |
|
| 1117 |
Returns:
|
| 1118 |
+
A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
|
| 1119 |
predictions relative to the activated input embeddings.
|
| 1120 |
+
"""
|
| 1121 |
modalities = self.compute_router_modalities(activated)
|
| 1122 |
innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
|
| 1123 |
innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
|
|
|
|
| 1125 |
if self.config.altup_coef_clip is not None:
|
| 1126 |
self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
|
| 1127 |
|
| 1128 |
+
# all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
|
| 1129 |
# Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
|
| 1130 |
# and expand on dim1 for broadcastability
|
| 1131 |
all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
|
|
|
|
| 1136 |
return corrected.contiguous().type_as(activated)
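# Toy numeric sketch (not part of the original file) of the update performed by `correct`
# above: each AltUp stream adds a per-stream coefficient times the innovation between the
# activated output and the active prediction. `coefs` stands in for
# `self.correction_coefs(modalities) + 1.0`; all shapes are illustrative.
import torch

num_inputs, batch, tokens, dim = 4, 1, 3, 8
active_idx = 0
predictions = torch.randn(num_inputs, batch, tokens, dim)
activated = torch.randn(batch, tokens, dim)

innovation = activated - predictions[active_idx]          # [B, T, D]
coefs = torch.rand(batch, tokens, num_inputs) + 1.0       # [B, T, num_inputs]
corrected = predictions + coefs.permute(2, 0, 1).unsqueeze(-1) * innovation
print(corrected.shape)                                    # torch.Size([4, 1, 3, 8])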
|
| 1137 |
|
| 1138 |
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
|
| 1139 |
+
"""
|
| 1140 |
+
This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
|
| 1141 |
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
|
| 1142 |
+
`scale_corrected_output`
|
| 1143 |
+
"""
|
| 1144 |
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
| 1145 |
|
| 1146 |
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
| 1147 |
+
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
|
| 1148 |
return self.forward(corrected)
|
| 1149 |
|
| 1150 |
|
| 1151 |
class Gemma3nTextRotaryEmbedding(nn.Module):
|
| 1152 |
def __init__(self, config: Gemma3nTextConfig, device=None):
|
| 1153 |
super().__init__()
|
| 1154 |
+
# BC: "rope_type" was originally "type"
|
| 1155 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 1156 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 1157 |
else:
|
| 1158 |
+
self.rope_type = "default"
|
| 1159 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 1160 |
self.original_max_seq_len = config.max_position_embeddings
|
| 1161 |
|
|
|
|
| 1163 |
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 1164 |
|
| 1165 |
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 1166 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 1167 |
self.original_inv_freq = self.inv_freq
|
| 1168 |
|
| 1169 |
@torch.no_grad()
|
|
|
|
| 1172 |
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 1173 |
position_ids_expanded = position_ids[:, None, :].float()
|
| 1174 |
|
| 1175 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 1176 |
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 1177 |
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 1178 |
emb = torch.cat((freqs, freqs), dim=-1)
|
|
|
|
| 1183 |
|
| 1184 |
|
| 1185 |
def rotate_half(x):
|
| 1186 |
+
"""Rotates half the hidden dims of the input."""
|
| 1187 |
x1 = x[..., : x.shape[-1] // 2]
|
| 1188 |
x2 = x[..., x.shape[-1] // 2 :]
|
| 1189 |
return torch.cat((-x2, x1), dim=-1)
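# Quick check (not part of the original file) of rotate_half on a tiny vector: the second
# half is negated and moved in front of the first half.
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
x1, x2 = x[: x.shape[-1] // 2], x[x.shape[-1] // 2 :]
print(torch.cat((-x2, x1), dim=-1))                       # tensor([-3., -4.,  1.,  2.])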
|
| 1190 |
|
| 1191 |
|
| 1192 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 1193 |
+
"""
|
| 1194 |
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 1195 |
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 1196 |
+
"""
|
| 1197 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 1198 |
if n_rep == 1:
|
| 1199 |
return hidden_states
|
|
|
|
| 1243 |
position_ids: Optional[torch.Tensor] = None,
|
| 1244 |
unsqueeze_dim: int = 1,
|
| 1245 |
):
|
| 1246 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 1247 |
|
| 1248 |
Args:
|
| 1249 |
+
x (`torch.Tensor`): The tensor to embed.
|
| 1250 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 1251 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 1252 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 1253 |
Deprecated and unused.
|
| 1254 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 1255 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 1256 |
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 1257 |
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 1258 |
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 1259 |
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 1260 |
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 1261 |
Returns:
|
| 1262 |
+
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
|
| 1263 |
+
"""
|
| 1264 |
cos = cos.unsqueeze(unsqueeze_dim)
|
| 1265 |
sin = sin.unsqueeze(unsqueeze_dim)
|
| 1266 |
return (x * cos) + (rotate_half(x) * sin)
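# Hedged usage sketch (not part of the original file): applying rotary cos/sin tables to a
# tensor shaped [batch, heads, seq, head_dim] with unsqueeze_dim=1, as apply_rotary_pos_emb
# does above. With cos=1 and sin=0 the rotation is the identity, which the check confirms.
import torch

def _rotate_half(t):
    t1, t2 = t[..., : t.shape[-1] // 2], t[..., t.shape[-1] // 2 :]
    return torch.cat((-t2, t1), dim=-1)

batch, heads, seq, head_dim = 1, 2, 3, 4
q = torch.randn(batch, heads, seq, head_dim)
cos = torch.ones(batch, seq, head_dim)
sin = torch.zeros(batch, seq, head_dim)
q_rot = (q * cos.unsqueeze(1)) + (_rotate_half(q) * sin.unsqueeze(1))
print(torch.allclose(q_rot, q))                           # True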
|
| 1267 |
|
| 1268 |
|
| 1269 |
class Gemma3nTextAttention(nn.Module):
|
| 1270 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 1271 |
|
| 1272 |
def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
|
| 1273 |
super().__init__()
|
| 1274 |
+
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
| 1275 |
self.config = config
|
| 1276 |
self.layer_idx = layer_idx
|
| 1277 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 1278 |
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 1279 |
self.attention_dropout = self.config.attention_dropout
|
| 1280 |
self.is_causal = True
|
|
|
|
| 1356 |
if past_key_value is not None:
|
| 1357 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 1358 |
cache_kwargs = {
|
| 1359 |
+
"sin": sin,
|
| 1360 |
+
"cos": cos,
|
| 1361 |
+
"cache_position": cache_position,
|
| 1362 |
+
"sliding_window": self.sliding_window,
|
| 1363 |
}
|
| 1364 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 1365 |
|
| 1366 |
attention_interface: Callable = eager_attention_forward
|
| 1367 |
+
if self.config._attn_implementation != "eager":
|
| 1368 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 1369 |
|
| 1370 |
attn_output, attn_weights = attention_interface(
|
|
|
|
| 1407 |
self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
|
| 1408 |
self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
| 1409 |
|
| 1410 |
+
@deprecate_kwarg("last_cache_position", version="4.53.0")
|
| 1411 |
def forward(
|
| 1412 |
self,
|
| 1413 |
hidden_states: torch.Tensor,
|
|
|
|
| 1460 |
if self.config.altup_correct_scale:
|
| 1461 |
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
| 1462 |
|
| 1463 |
+
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
| 1464 |
first_prediction = self.per_layer_input_gate(first_prediction)
|
| 1465 |
first_prediction = self.act_fn(first_prediction)
|
| 1466 |
first_prediction = torch.multiply(first_prediction, per_layer_input)
|
| 1467 |
|
| 1468 |
+
# per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
| 1469 |
first_prediction = self.per_layer_projection(first_prediction)
|
| 1470 |
first_prediction = self.post_per_layer_input_norm(first_prediction)
|
| 1471 |
corrected_predictions[1:] += first_prediction
|
|
|
|
| 1481 |
@auto_docstring
|
| 1482 |
class Gemma3nPreTrainedModel(PreTrainedModel):
|
| 1483 |
config_class = Gemma3nConfig
|
| 1484 |
+
base_model_prefix = ""
|
| 1485 |
supports_gradient_checkpointing = True
|
| 1486 |
+
_no_split_modules = ["Gemma3nTextDecoderLayer"]
|
| 1487 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 1488 |
_supports_flash_attn_3 = True
|
| 1489 |
_supports_flash_attn_2 = True
|
| 1490 |
_supports_sdpa = True
|
|
|
|
| 1495 |
_supports_attention_backend = True
|
| 1496 |
|
| 1497 |
def _init_weights(self, module):
|
| 1498 |
+
# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
| 1499 |
# inference and fine-tuning - so the proper init weights code has been removed
|
| 1500 |
+
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
| 1501 |
|
| 1502 |
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
|
| 1503 |
module.weight.data.normal_(mean=0.0, std=std)
|
|
|
|
| 1518 |
module.correct_output_scale.data.zero_()
|
| 1519 |
|
| 1520 |
|
| 1521 |
+
@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
|
| 1522 |
class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
| 1523 |
config_class = Gemma3nTextConfig
|
| 1524 |
|
|
|
|
| 1544 |
# defaults should hold values for global RoPE.
|
| 1545 |
config = copy.deepcopy(config)
|
| 1546 |
config.rope_theta = config.rope_local_base_freq
|
| 1547 |
+
config.rope_scaling = {"rope_type": "default"}
|
| 1548 |
self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
|
| 1549 |
|
| 1550 |
self.hidden_size = config.hidden_size
|
|
|
|
| 1573 |
[nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
|
| 1574 |
)
|
| 1575 |
|
| 1576 |
+
self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
|
| 1577 |
+
self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
|
| 1578 |
|
| 1579 |
# Initialize weights and apply final processing
|
| 1580 |
self.post_init()
|
|
|
|
| 1601 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1602 |
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 1603 |
) -> BaseModelOutputWithPast:
|
| 1604 |
+
r"""
|
| 1605 |
per_layer_inputs (torch.Tensor, *optional*, defaults to None):
|
| 1606 |
Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
|
| 1607 |
+
"""
|
| 1608 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1609 |
output_hidden_states = (
|
| 1610 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
| 1612 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1613 |
|
| 1614 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1615 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1616 |
|
| 1617 |
if self.gradient_checkpointing and self.training and use_cache:
|
| 1618 |
logger.warning_once(
|
| 1619 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 1620 |
)
|
| 1621 |
use_cache = False
|
| 1622 |
|
|
|
|
| 1640 |
if position_ids is None:
|
| 1641 |
position_ids = cache_position.unsqueeze(0)
|
| 1642 |
|
| 1643 |
+
# It may already have been prepared by e.g. `generate`
|
| 1644 |
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 1645 |
# Prepare mask arguments
|
| 1646 |
mask_kwargs = {
|
| 1647 |
+
"config": self.config,
|
| 1648 |
+
"input_embeds": inputs_embeds,
|
| 1649 |
+
"attention_mask": attention_mask,
|
| 1650 |
+
"cache_position": cache_position,
|
| 1651 |
+
"past_key_values": past_key_values,
|
| 1652 |
}
|
| 1653 |
# Create the masks
|
| 1654 |
causal_mask_mapping = {
|
| 1655 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 1656 |
+
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
|
| 1657 |
}
|
| 1658 |
|
| 1659 |
# embed positions
|
|
|
|
| 1669 |
|
| 1670 |
temp_hidden_states = [hidden_states_0]
|
| 1671 |
for i in range(1, self.config.altup_num_inputs):
|
| 1672 |
+
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
| 1673 |
altup_proj = self.altup_projections[i - 1](hidden_states_0)
|
| 1674 |
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
| 1675 |
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
|
|
|
| 1717 |
target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
|
| 1718 |
temp_hidden_states = [hidden_states[0]]
|
| 1719 |
for i in range(1, self.config.altup_num_inputs):
|
| 1720 |
+
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
| 1721 |
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
|
| 1722 |
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
| 1723 |
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
|
|
|
| 1771 |
)
|
| 1772 |
|
| 1773 |
|
| 1774 |
+
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
|
| 1775 |
class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
| 1776 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1777 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 1778 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 1779 |
config_class = Gemma3nTextConfig
|
| 1780 |
+
base_model_prefix = "model"
|
| 1781 |
+
_checkpoint_conversion_mapping = {"model.language_model": "model"}
|
| 1782 |
|
| 1783 |
def __init__(self, config: Gemma3nTextConfig):
|
| 1784 |
super().__init__(config)
|
|
|
|
| 1824 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1825 |
**loss_kwargs,
|
| 1826 |
) -> CausalLMOutputWithPast:
|
| 1827 |
+
r"""
|
| 1828 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1829 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1830 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1831 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1832 |
|
| 1833 |
Example:
|
| 1834 |
|
| 1835 |
+
```python
|
| 1836 |
>>> from transformers import AutoTokenizer, Gemma3nForCausalLM
|
| 1837 |
|
| 1838 |
+
>>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
|
| 1839 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
| 1840 |
|
| 1841 |
+
>>> prompt = "What is your favorite condiment?"
|
| 1842 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1843 |
|
| 1844 |
>>> # Generate
|
| 1845 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1846 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1847 |
+
"What is your favorite condiment?"
|
| 1848 |
+
```"""
|
| 1849 |
|
| 1850 |
+
if self.training and self.config._attn_implementation != "eager":
|
| 1851 |
logger.warning_once(
|
| 1852 |
+
"It is strongly recommended to train Gemma3n models with the `eager` attention implementation "
|
| 1853 |
+
f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
|
| 1854 |
)
|
| 1855 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1856 |
output_hidden_states = (
|
|
|
|
| 1893 |
|
| 1894 |
|
| 1895 |
class Gemma3nMultimodalEmbedder(nn.Module):
|
| 1896 |
+
"""Embeds token ids or soft tokens for multimodal content into language model space."""
|
| 1897 |
|
| 1898 |
def __init__(
|
| 1899 |
self,
|
|
|
|
| 1919 |
input_ids: Optional[torch.LongTensor] = None,
|
| 1920 |
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1921 |
) -> torch.Tensor:
|
| 1922 |
+
"""Embeds token ids or soft tokens for multimodal content into language model space.
|
| 1923 |
|
| 1924 |
Args:
|
| 1925 |
input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
|
| 1926 |
+
`[vocab_offset, vocab_offset + vocab_size)`.
|
| 1927 |
inputs_embeds: A torch.Tensor containing the soft tokens to embed.
|
| 1928 |
|
| 1929 |
Returns:
|
| 1930 |
+
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
|
| 1931 |
+
"""
|
| 1932 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1933 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1934 |
|
| 1935 |
if inputs_embeds is not None:
|
| 1936 |
emb_norm = self.soft_embedding_norm(inputs_embeds)
|
|
|
|
| 1943 |
|
| 1944 |
|
| 1945 |
@auto_docstring(
|
| 1946 |
+
custom_intro="""
|
| 1947 |
The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
|
| 1948 |
language modeling head.
|
| 1949 |
+
"""
|
| 1950 |
)
|
| 1951 |
class Gemma3nModel(Gemma3nPreTrainedModel):
|
| 1952 |
_checkpoint_conversion_mapping = {}
|
| 1953 |
+
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
| 1954 |
accepts_loss_kwargs = False
|
| 1955 |
|
| 1956 |
def __init__(self, config: Gemma3nConfig):
|
|
|
|
| 1981 |
return self.language_model
|
| 1982 |
|
| 1983 |
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 1984 |
+
"""
|
| 1985 |
Projects the last hidden state from the vision model into language model space.
|
| 1986 |
|
| 1987 |
Args:
|
| 1988 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
|
| 1989 |
The tensors corresponding to the input images.
|
| 1990 |
|
| 1991 |
Returns:
|
| 1992 |
+
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
| 1993 |
+
"""
|
| 1994 |
vision_outputs = self.vision_tower(
|
| 1995 |
pixel_values=pixel_values, do_pooling=False, return_dict=True
|
| 1996 |
).last_hidden_state
|
|
|
|
| 2024 |
output_hidden_states: Optional[bool] = None,
|
| 2025 |
**lm_kwargs,
|
| 2026 |
) -> Gemma3nCausalLMOutputWithPast:
|
| 2027 |
+
r"""
|
| 2028 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 2029 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 2030 |
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 2031 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
| 2032 |
|
| 2033 |
Example:
|
| 2034 |
|
| 2035 |
+
```python
|
| 2036 |
>>> from PIL import Image
|
| 2037 |
>>> import requests
|
| 2038 |
>>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
|
| 2039 |
|
| 2040 |
+
>>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
|
| 2041 |
+
>>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
|
| 2042 |
|
| 2043 |
+
>>> prompt = "Where is the cat standing?"
|
| 2044 |
+
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
| 2045 |
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 2046 |
|
| 2047 |
+
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
| 2048 |
|
| 2049 |
>>> # Generate
|
| 2050 |
>>> generate_ids = model.generate(**inputs,)
|
| 2051 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 2052 |
+
"Where is the cat standing?\nsnow"
|
| 2053 |
+
```
|
| 2054 |
+
"""
|
| 2055 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 2056 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 2057 |
|
| 2058 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2059 |
output_hidden_states = (
|
|
|
|
| 2103 |
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 2104 |
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
| 2105 |
raise ValueError(
|
| 2106 |
+
f"Number of images does not match number of special image tokens in the input text. "
|
| 2107 |
+
f"Got {image_tokens_in_text} image tokens in the text and "
|
| 2108 |
+
f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
|
| 2109 |
)
|
| 2110 |
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 2111 |
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
|
|
|
| 2140 |
if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
|
| 2141 |
audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
|
| 2142 |
raise ValueError(
|
| 2143 |
+
f"Number of audio input features does not match number of special audio tokens in the input text. "
|
| 2144 |
+
f"Got {audio_tokens_in_text} audio tokens in the text and "
|
| 2145 |
+
f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings."
|
| 2146 |
)
|
| 2147 |
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 2148 |
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
|
|
|
|
| 2174 |
def get_audio_features(
|
| 2175 |
self, input_features: torch.Tensor, input_features_mask: torch.Tensor
|
| 2176 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 2177 |
+
"""
|
| 2178 |
Projects the last hidden state from the audio encoder into language model space.
|
| 2179 |
|
| 2180 |
Args:
|
| 2181 |
+
input_features (`torch.FloatTensor` of shape `(num_images, seq_length, num_features)`):
|
| 2182 |
The tensors corresponding to the input audio.
|
| 2183 |
+
input_features (\`torch.FloatTensor]\` of shape \`(num_images, seq_length)\`):
|
| 2184 |
The attention mask for the input audio.
|
| 2185 |
|
| 2186 |
Returns:
|
| 2187 |
+
audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_images, audio_length, embed_dim)`.
|
| 2188 |
+
"""
|
| 2189 |
audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
|
| 2190 |
return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
|
| 2191 |
|
| 2192 |
|
| 2193 |
@auto_docstring(
|
| 2194 |
+
custom_intro="""
|
| 2195 |
The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
|
| 2196 |
head.
|
| 2197 |
+
"""
|
| 2198 |
)
|
| 2199 |
class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
|
| 2200 |
_checkpoint_conversion_mapping = {}
|
| 2201 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 2202 |
+
base_model_prefix = "model"
|
| 2203 |
|
| 2204 |
def __init__(self, config: Gemma3nConfig):
|
| 2205 |
super().__init__(config)
|
|
|
|
| 2239 |
|
| 2240 |
@property
|
| 2241 |
def multi_modal_projector(self):
|
| 2242 |
+
raise AttributeError("Use embed_vision instead of multi_modal_projector.")
|
| 2243 |
|
| 2244 |
@can_return_tuple
|
| 2245 |
@auto_docstring
|
|
|
|
| 2262 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 2263 |
**lm_kwargs,
|
| 2264 |
) -> Gemma3nCausalLMOutputWithPast:
|
| 2265 |
+
r"""
|
| 2266 |
input_features (torch.Tensor, *optional*, defaults to None):
|
| 2267 |
The audio inputs to be encoded.
|
| 2268 |
input_features_mask (torch.Tensor, *optional*, defaults to None):
|
| 2269 |
The attention mask for the input audio.
|
| 2270 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 2271 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 2272 |
+
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
|
| 2273 |
ignored (masked), the loss is only computed for the tokens with labels in
|
| 2274 |
+
`[0, ..., config.text_config.vocab_size]`.
|
| 2275 |
|
| 2276 |
Example:
|
| 2277 |
|
| 2278 |
+
```python
|
| 2279 |
>>> from PIL import Image
|
| 2280 |
>>> import requests
|
| 2281 |
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 2282 |
|
| 2283 |
+
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
|
| 2284 |
+
>>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
|
| 2285 |
|
| 2286 |
>>> messages = [
|
| 2287 |
... {
|
| 2288 |
+
...         "role": "system",
|
| 2289 |
+
...         "content": [
|
| 2290 |
+
...             {"type": "text", "text": "You are a helpful assistant."}
|
| 2291 |
... ]
|
| 2292 |
... },
|
| 2293 |
... {
|
| 2294 |
+
...         "role": "user", "content": [
|
| 2295 |
+
...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
| 2296 |
+
...             {"type": "text", "text": "Where is the cat standing?"},
|
| 2297 |
... ]
|
| 2298 |
... },
|
| 2299 |
... ]
|
|
|
|
| 2302 |
... messages,
|
| 2303 |
...     tokenize=True,
|
| 2304 |
... return_dict=True,
|
| 2305 |
+
...     return_tensors="pt",
|
| 2306 |
... add_generation_prompt=True
|
| 2307 |
... )
|
| 2308 |
>>> # Generate
|
| 2309 |
>>> generate_ids = model.generate(**inputs)
|
| 2310 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 2311 |
+
"user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
|
| 2312 |
+
```
|
| 2313 |
+
"""
|
| 2314 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 2315 |
output_hidden_states = (
|
| 2316 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
|
| 2393 |
labels=None,
|
| 2394 |
**kwargs,
|
| 2395 |
):
|
| 2396 |
+
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
| 2397 |
model_inputs = super().prepare_inputs_for_generation(
|
| 2398 |
input_ids,
|
| 2399 |
past_key_values=past_key_values,
|
|
|
|
| 2407 |
**kwargs,
|
| 2408 |
)
|
| 2409 |
|
| 2410 |
+
# If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
|
| 2411 |
# tokens anymore. Otherwise multimodal inputs should be passed to model.
|
| 2412 |
# NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
|
| 2413 |
if cache_position[0] == 0:
|
| 2414 |
+
model_inputs["pixel_values"] = pixel_values
|
| 2415 |
+
model_inputs["input_features"] = input_features
|
| 2416 |
+
model_inputs["input_features_mask"] = input_features_mask
|
| 2417 |
|
| 2418 |
return model_inputs
|
| 2419 |
|
|
|
|
| 2423 |
|
| 2424 |
|
| 2425 |
__all__ = [
|
| 2426 |
+
"Gemma3nAudioEncoder",
|
| 2427 |
+
"Gemma3nForCausalLM",
|
| 2428 |
+
"Gemma3nForConditionalGeneration",
|
| 2429 |
+
"Gemma3nModel",
|
| 2430 |
+
"Gemma3nPreTrainedModel",
|
| 2431 |
+
"Gemma3nTextModel",
|
| 2432 |
]
|