__all__ = [
    "xLSTMConfig",
    "xLSTMLMHeadModel",
]

import json
import os
from collections import namedtuple
from dataclasses import asdict

import torch
import torch.nn as nn
from dacite import Config as DaciteConfig, from_dict
from omegaconf import OmegaConf
from transformers import PretrainedConfig

from protxlstm.generation import GenerationMixinSafe
from protxlstm.utils import load_config_hf, load_state_dict_hf
from protxlstm.xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig


class xLSTMConfig(PretrainedConfig):
    def __init__(self):
        self.config_dataclass = xLSTMLMModelConfig()

    def init_from_dict(self, config: dict):
        config = OmegaConf.create(config)
        self.config_dataclass = from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(config),
            config=DaciteConfig(strict=True),
        )
        return self

    def to_dict(self):
        return asdict(self.config_dataclass)
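
# Illustrative only: a fragment of the nested dict that `init_from_dict` expects.
# The exact set of required fields is defined by `xLSTMLMModelConfig` (dacite runs
# with strict=True, so unknown keys raise); the keys shown below are the ones touched
# by `from_pretrained` further down, and the values are assumptions, not a schema.
#
#     config_fragment = {
#         "rope_base_frequency": 10000,
#         "checkpoint_blocks": False,
#         "mlstm_block": {
#             "mlstm": {
#                 "backend": "chunkwise",  # or "chunkwise_variable" / "parallel"
#                 "chunk_size": 64,
#                 "return_last_state": False,
#             }
#         },
#     }
#     config = xLSTMConfig().init_from_dict(full_config_dict)  # pass the full dict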


class xLSTMLMHeadModel(nn.Module, GenerationMixinSafe):
    def __init__(self, config: xLSTMConfig) -> None:
        super().__init__()
        self.config = config
        self.backbone = xLSTMLMModel(self.config.config_dataclass)
        self.backbone.reset_parameters()
        self.setup()

    def setup(self):
        # Pick the local device index from the launcher environment (torchrun sets
        # LOCAL_RANK, SLURM sets SLURM_LOCALID); the device moves below are
        # intentionally left commented out.
        if 'LOCAL_RANK' in os.environ:
            current_device = int(os.environ['LOCAL_RANK'])
        elif 'SLURM_LOCALID' in os.environ:
            current_device = int(os.environ['SLURM_LOCALID'])
        else:
            current_device = 0
        # torch.cuda.set_device(f'cuda:{current_device}')
        # self.backbone = self.backbone.to("cuda")

    def forward(
        self,
        input_ids,
        state=None,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        **kwargs,
    ):
        if self.config.config_dataclass.mlstm_block.mlstm.return_last_state:
            # The backbone also returns its recurrent state, so pass it through.
            lm_logits, state = self.backbone(
                input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids, state=state
            )
            CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "state"])
            return CausalLMOutput(loss=None, logits=lm_logits, state=state)
        else:
            lm_logits = self.backbone(
                input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids, state=state
            )
            CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
            return CausalLMOutput(loss=None, logits=lm_logits)

    def step(
        self,
        input_ids,
        state=None,
        position_ids=None,
        seq_position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        save_layer=[],
        **kwargs,
    ):
        lm_logits, state = self.backbone.step(
            input_ids, state=state, position_ids=position_ids, seq_position_ids=seq_position_ids
        )
        return lm_logits, state
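
    # A minimal sketch (assumptions, not part of the module): `step` is meant for
    # incremental decoding, feeding one new token at a time and threading the
    # returned state back in, along these lines:
    #
    #     state = None
    #     for _ in range(num_new_tokens):  # num_new_tokens is hypothetical
    #         logits, state = model.step(next_ids, state=state)
    #         next_ids = logits[:, -1].argmax(-1, keepdim=True)
    #
    # The full generation entry point is expected to come from GenerationMixinSafe.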

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name,
        device=None,
        dtype=None,
        mlstm_backend=None,
        mlstm_chunksize=None,
        checkpoint_blocks=None,
        rope_base_frequency=None,
        mlstm_return_last_state=None,
    ):
        # Load the checkpoint config
        config_dict = load_config_hf(pretrained_model_name)

        # Update the RoPE base frequency
        if rope_base_frequency is not None and config_dict.get("rope_base_frequency", None) != rope_base_frequency:
            config_dict["rope_base_frequency"] = rope_base_frequency

        # Update the mLSTM backend
        if mlstm_backend is not None and config_dict["mlstm_block"]["mlstm"].get("backend", None) != mlstm_backend:
            assert mlstm_backend in ["chunkwise", "chunkwise_variable", "parallel"], "Invalid mLSTM backend."
            config_dict["mlstm_block"]["mlstm"]["backend"] = mlstm_backend

        # Update the mLSTM chunk size
        if mlstm_chunksize is not None and config_dict["mlstm_block"]["mlstm"].get("chunk_size", None) != mlstm_chunksize:
            config_dict["mlstm_block"]["mlstm"]["chunk_size"] = mlstm_chunksize

        # Update activation checkpointing
        if checkpoint_blocks is not None:
            config_dict["checkpoint_blocks"] = checkpoint_blocks

        if mlstm_return_last_state is not None:
            config_dict["mlstm_block"]["mlstm"]["return_last_state"] = mlstm_return_last_state

        # Remove sLSTM entries from the checkpoint config, if present
        if "slstm_block" in config_dict:
            config_dict.pop("slstm_block")
        if "slstm_at" in config_dict:
            config_dict.pop("slstm_at")

        config = xLSTMConfig().init_from_dict(config_dict)
        model = cls(config)

        state_dict = load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        assert (
            state_dict.keys() == model.state_dict().keys()
        ), "The keys of the state_dict do not match the model's keys."
        model.load_state_dict(state_dict)
        return model

    def save_pretrained(self, save_directory):
        """
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.to_dict(), f)
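

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. The checkpoint identifier below
    # is a placeholder: `load_config_hf` / `load_state_dict_hf` decide how it is
    # resolved (e.g. a Hugging Face repo id or a local path), and the token ids are
    # dummy values rather than output of the real protein tokenizer.
    checkpoint = "path/or/repo-id-of-a-prot-xlstm-checkpoint"  # placeholder
    model = xLSTMLMHeadModel.from_pretrained(checkpoint, device="cpu", dtype=torch.float32)
    model.eval()

    input_ids = torch.randint(0, 20, (1, 16))  # dummy tokens; real ids come from the tokenizer
    with torch.no_grad():
        out = model(input_ids)
    print(out.logits.shape)  # expected: (batch, seq_len, vocab_size)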