| from typing import Optional, Tuple | |
| from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel | |
| import torch | |
| from transformers import GenerationMixin, PreTrainedModel | |
| from transformers.generation import TextStreamer | |
| from .configuration_mamba import MambaConfig | |
class MambaModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``MambaLMHeadModel``.

    The wrapped network is stored as ``self.model``; this class only forwards
    construction arguments and forward-pass inputs to it.
    """

    config_class = MambaConfig

    def __init__(
        self,
        config,
        initializer_cfg=None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """Build the wrapper and the underlying ``MambaLMHeadModel``.

        Args:
            config: Configuration object, handed to both
                ``PreTrainedModel.__init__`` and ``MambaLMHeadModel``.
            initializer_cfg: Forwarded unchanged to ``MambaLMHeadModel``.
            device: Forwarded unchanged to ``MambaLMHeadModel``.
            dtype: Forwarded unchanged to ``MambaLMHeadModel``.
            **kwargs: Extra keyword arguments for
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config, **kwargs)
        self.model = MambaLMHeadModel(
            config,
            initializer_cfg=initializer_cfg,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **kwargs,
    ):
        """Run the wrapped model and return its output unchanged.

        Args:
            input_ids: Forwarded as the first positional argument.
            position_ids: Forwarded as the second positional argument.
            inference_params: Forwarded as the third positional argument.
            num_last_tokens: Forwarded as the fourth positional argument.
            **kwargs: Accepted for call-site compatibility and ignored; none
                of these are passed to the wrapped model.

        Returns:
            Whatever the wrapped ``MambaLMHeadModel`` returns.
        """
        # Call the module instance rather than ``.forward`` so that hooks
        # registered on the wrapped module are executed as well.
        return self.model(
            input_ids,
            position_ids,
            inference_params,
            num_last_tokens,
        )
class MambaModelForCausalLM(MambaModel, GenerationMixin):
    """``MambaModel`` variant whose ``generate`` delegates to the wrapped model."""

    def generate(
        self,
        input_ids,
        max_length: int = 2048,
        top_k: int = 1,
        top_p: float = 0.0,
        temperature: float = 1.0,
        return_dict_in_generate: bool = False,
        output_scores: bool = False,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
        teacher_outputs: Optional[torch.Tensor] = None,
        vocab_size: Optional[int] = None,
        cg: bool = False,
        enable_timing: bool = False,
        streamer: Optional[TextStreamer] = None,
        **kwargs,
    ):
        """Generate by calling ``self.model.generate`` with the named options.

        Every named parameter is passed through by keyword, unchanged, to the
        wrapped model's ``generate``. Any additional ``**kwargs`` are accepted
        and discarded; they never reach the wrapped model.

        Returns:
            Whatever ``self.model.generate`` returns.
        """
        # Gather the pass-through options first so the delegation itself is a
        # single call; names and values are handed over untouched.
        generation_options = {
            "input_ids": input_ids,
            "max_length": max_length,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "return_dict_in_generate": return_dict_in_generate,
            "output_scores": output_scores,
            "repetition_penalty": repetition_penalty,
            "eos_token_id": eos_token_id,
            "teacher_outputs": teacher_outputs,
            "vocab_size": vocab_size,
            "cg": cg,
            "enable_timing": enable_timing,
            "streamer": streamer,
        }
        return self.model.generate(**generation_options)