Update modeling_minicpm.py
modeling_minicpm.py  CHANGED  +42 -21
@@ -21,12 +21,16 @@
 import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict
-
+import os
+from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+import numpy as np
+from copy import deepcopy
+from transformers import AutoTokenizer

 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -35,6 +39,7 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
+    _prepare_4d_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -320,9 +325,6 @@ class MiniCPMAttention(nn.Module):
         self.rope_theta = config.rope_theta

         self.is_causal = config.is_causal
-
-        logger.info(f"self.is_causal = {self.is_causal}")
-

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -979,6 +981,8 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
+        self.is_causal = config.is_causal
+        self.adapt_mean_pooling = config.adapt_mean_pooling
         # Initialize weights and apply final processing
         self.post_init()

@@ -1000,6 +1004,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        adapt_mean_pooling: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1044,24 +1049,35 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb

         _attention_mask = attention_mask
-
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         elif self._use_sdpa and not output_attentions:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
+            if self.is_causal:
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            else:
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask,
+                    inputs_embeds.dtype,
+                )
         else:
             # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+            if self.is_causal:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )
+            else:
+                attention_mask = _prepare_4d_attention_mask(
+                    attention_mask,
+                    inputs_embeds.dtype,
+                )

         # embed positions
         hidden_states = inputs_embeds
@@ -1109,14 +1125,18 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        # gen weight before mean pooling
-        attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
-        s = hidden_states * attention_mask_.unsqueeze(-1).float()
-        d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() /_attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
+        next_cache = None

-        hidden_states = s / d
+        # gen weight before mean pooling
+        if adapt_mean_pooling is None:
+            adapt_mean_pooling = self.adapt_mean_pooling
+        if adapt_mean_pooling:
+            attention_mask_ = _attention_mask * _attention_mask.cumsum(dim=1)
+            s = hidden_states * attention_mask_.unsqueeze(-1).float()
+            d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() /_attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
+
+            hidden_states = s / d

-        next_cache = None
         if use_cache:
             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
@@ -1127,7 +1147,8 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-
+
+

 class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
|