Spaces:
Runtime error
Runtime error
import sys
from dataclasses import dataclass
@dataclass
class XftConfig:
    """Settings bundle passed to ``load_xft_model`` and stored on ``XftModel``.

    Declared as a dataclass so that individual settings can be overridden by
    keyword at construction time, e.g. ``XftConfig(data_type="bf16")``; every
    field has a default, so ``XftConfig()`` is also valid.

    Within this file only ``data_type`` is read (by ``load_xft_model``, which
    forwards it as the ``dtype`` argument when loading the model). The other
    fields are carried along unchanged for whoever consumes the config.
    NOTE(review): the meaning of the remaining fields is inferred from their
    names only -- confirm against the code that consumes them.
    """

    max_seq_len: int = 4096
    beam_width: int = 1
    eos_token_id: int = -1
    pad_token_id: int = -1
    num_return_sequences: int = 1
    is_encoder_decoder: bool = False
    padding: bool = True
    early_stopping: bool = False
    # ``load_xft_model`` falls back to "bf16_fp16" when this is None or "".
    data_type: str = "bf16_fp16"
class XftModel:
    """Pair a loaded model object with the configuration it was loaded with.

    Attributes are plain references to the constructor arguments; nothing is
    copied or validated.
    """

    def __init__(self, xft_model, xft_config):
        """Store ``xft_model`` as ``self.model`` and ``xft_config`` as ``self.config``."""
        self.model, self.config = xft_model, xft_config
def load_xft_model(model_path, xft_config: XftConfig):
    """Load an xFasterTransformer model and its tokenizer from ``model_path``.

    Args:
        model_path: Passed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``xfastertransformer.AutoModel.from_pretrained``.
        xft_config: Settings object. Only ``data_type`` is read here; the
            object itself is stored on the returned wrapper.

    Returns:
        A ``(model, tokenizer)`` tuple where ``model`` is an ``XftModel``
        wrapping the loaded model object.

    Note:
        If either package cannot be imported, an error line is printed and the
        process exits through ``sys.exit(-1)``. If the loaded model reports
        ``rank > 0`` this function never returns: it calls ``generate()`` in
        an endless loop instead.
    """
    # Imported inside the function, so a missing package is reported at call
    # time rather than when this module is imported.
    try:
        import xfastertransformer
        from transformers import AutoTokenizer
    except ImportError as exc:
        print(f"Error: Failed to load xFasterTransformer. {exc}")
        sys.exit(-1)

    # An unset (None) or empty data type falls back to "bf16_fp16".
    requested_dtype = xft_config.data_type
    data_type = (
        "bf16_fp16"
        if requested_dtype is None or requested_dtype == ""
        else requested_dtype
    )

    tokenizer_options = {
        "use_fast": False,
        "padding_side": "left",
        "trust_remote_code": True,
    }
    tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_options)

    engine = xfastertransformer.AutoModel.from_pretrained(model_path, dtype=data_type)
    wrapped = XftModel(xft_model=engine, xft_config=xft_config)

    # NOTE(review): presumably ranks above 0 are worker processes in a
    # multi-rank run that only need to keep servicing generate() -- confirm
    # against the xFasterTransformer documentation. As written, this branch
    # never reaches the return below.
    if wrapped.model.rank > 0:
        while True:
            wrapped.model.generate()

    return wrapped, tokenizer