Update code
Browse files- modeling_minicpm.py +32 -5
modeling_minicpm.py
CHANGED
|
@@ -20,7 +20,7 @@
|
|
| 20 |
""" PyTorch MiniCPM model."""
|
| 21 |
import math
|
| 22 |
import warnings
|
| 23 |
-
from typing import List, Optional, Tuple, Union
|
| 24 |
|
| 25 |
import torch
|
| 26 |
import torch.nn.functional as F
|
|
@@ -49,11 +49,13 @@ from transformers.utils import (
|
|
| 49 |
)
|
| 50 |
from transformers.utils.import_utils import is_torch_fx_available
|
| 51 |
from .configuration_minicpm import MiniCPMConfig
|
|
|
|
| 52 |
|
| 53 |
-
|
| 54 |
-
if is_flash_attn_2_available():
|
| 55 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 56 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
|
@@ -124,7 +126,7 @@ ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
|
|
| 124 |
|
| 125 |
|
| 126 |
class MiniCPMRotaryEmbedding(nn.Module):
|
| 127 |
-
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=
|
| 128 |
super().__init__()
|
| 129 |
|
| 130 |
self.dim = dim
|
|
@@ -762,7 +764,6 @@ class MiniCPMDecoderLayer(nn.Module):
|
|
| 762 |
def __init__(self, config: MiniCPMConfig, layer_idx: int):
|
| 763 |
super().__init__()
|
| 764 |
self.hidden_size = config.hidden_size
|
| 765 |
-
|
| 766 |
self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 767 |
|
| 768 |
self.mlp = MiniCPMMLP(config)
|
|
@@ -1302,6 +1303,32 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
|
|
| 1302 |
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
| 1303 |
)
|
| 1304 |
return reordered_past
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1305 |
|
| 1306 |
|
| 1307 |
@add_start_docstrings(
|
|
|
|
| 20 |
""" PyTorch MiniCPM model."""
|
| 21 |
import math
|
| 22 |
import warnings
|
| 23 |
+
from typing import List, Optional, Tuple, Union, Dict
|
| 24 |
|
| 25 |
import torch
|
| 26 |
import torch.nn.functional as F
|
|
|
|
| 49 |
)
|
| 50 |
from transformers.utils.import_utils import is_torch_fx_available
|
| 51 |
from .configuration_minicpm import MiniCPMConfig
|
| 52 |
+
import re
|
| 53 |
|
| 54 |
+
try:
|
|
|
|
| 55 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 56 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 57 |
+
except:
|
| 58 |
+
pass
|
| 59 |
|
| 60 |
|
| 61 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
class MiniCPMRotaryEmbedding(nn.Module):
|
| 129 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 130 |
super().__init__()
|
| 131 |
|
| 132 |
self.dim = dim
|
|
|
|
| 764 |
def __init__(self, config: MiniCPMConfig, layer_idx: int):
|
| 765 |
super().__init__()
|
| 766 |
self.hidden_size = config.hidden_size
|
|
|
|
| 767 |
self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 768 |
|
| 769 |
self.mlp = MiniCPMMLP(config)
|
|
|
|
| 1303 |
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
| 1304 |
)
|
| 1305 |
return reordered_past
|
| 1306 |
+
|
| 1307 |
+
@torch.inference_mode()
|
| 1308 |
+
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
| 1309 |
+
max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
|
| 1310 |
+
**kwargs):
|
| 1311 |
+
if history is None:
|
| 1312 |
+
history = []
|
| 1313 |
+
if logits_processor:
|
| 1314 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
| 1315 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
| 1316 |
+
else:
|
| 1317 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
| 1318 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
| 1319 |
+
|
| 1320 |
+
history.append({"role": role, "content": query})
|
| 1321 |
+
history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
|
| 1322 |
+
inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
|
| 1323 |
+
outputs = self.generate(**inputs, **gen_kwargs)
|
| 1324 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
| 1325 |
+
response = tokenizer.decode(outputs)
|
| 1326 |
+
pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
|
| 1327 |
+
matches = pattern.findall(response)
|
| 1328 |
+
if len(matches) > 0:
|
| 1329 |
+
response = matches[0]
|
| 1330 |
+
history.append({"role": "assistant", "content": response})
|
| 1331 |
+
return response, history
|
| 1332 |
|
| 1333 |
|
| 1334 |
@add_start_docstrings(
|