Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass, field | |
| import sys | |
@dataclass
class ExllamaConfig:
    """Options controlling how an ExLlamaV2 model is loaded.

    Attributes:
        max_seq_len: Maximum sequence length (context window) to configure
            on the loaded model.
        gpu_split: Comma-separated per-GPU allocation string (each entry is
            parsed as a float by the loader), or None for the default
            single-device placement. Semantically ``Optional[str]``; the
            annotation is kept as ``str`` to preserve the original interface.
        cache_8bit: If True, the loader uses the 8-bit KV-cache class.
    """

    # BUGFIX: the @dataclass decorator was missing. Without it the class has
    # no generated __init__, so ExllamaConfig(max_seq_len=...) raises
    # TypeError, and `max_seq_len` is a bare annotation with no attribute
    # behind it. The file already imports `dataclass`, confirming intent.
    max_seq_len: int
    gpu_split: str = None
    cache_8bit: bool = False
class ExllamaModel:
    """Bundle a loaded ExLlamaV2 model together with its KV cache.

    The underlying model's config object is re-exposed as ``self.config``
    so callers can read model settings without reaching through
    ``self.model``.
    """

    def __init__(self, exllama_model, exllama_cache):
        self.model, self.cache = exllama_model, exllama_cache
        self.config = exllama_model.config
def load_exllama_model(model_path, exllama_config: ExllamaConfig):
    """Load an ExLlamaV2 model, KV cache, and tokenizer from *model_path*.

    Args:
        model_path: Directory containing the ExLlamaV2 model files.
        exllama_config: Loader options (max sequence length, optional GPU
            split string, 8-bit cache flag).

    Returns:
        A tuple ``(model, tokenizer)`` where *model* is an ``ExllamaModel``
        wrapping the loaded model and its cache.

    Exits the process with status -1 if the ``exllamav2`` package cannot
    be imported.
    """
    try:
        from exllamav2 import (
            ExLlamaV2,
            ExLlamaV2Cache,
            ExLlamaV2Cache_8bit,
            ExLlamaV2Config,
            ExLlamaV2Tokenizer,
        )
    except ImportError as e:
        # BUGFIX: report the failure on stderr so it is not lost when
        # stdout is piped or captured.
        print(f"Error: Failed to load Exllamav2. {e}", file=sys.stderr)
        sys.exit(-1)

    exllamav2_config = ExLlamaV2Config()
    exllamav2_config.model_dir = model_path
    exllamav2_config.prepare()
    # Overrides are applied after prepare(), matching the original ordering
    # — presumably so prepare() does not clobber them with values from the
    # model's own config files; confirm against exllamav2 docs.
    exllamav2_config.max_seq_len = exllama_config.max_seq_len
    exllamav2_config.cache_8bit = exllama_config.cache_8bit

    exllama_model = ExLlamaV2(exllamav2_config)
    tokenizer = ExLlamaV2Tokenizer(exllamav2_config)

    # gpu_split is a comma-separated allocation string (e.g. "17.2,24");
    # None lets ExLlamaV2 choose its default placement.
    split = None
    if exllama_config.gpu_split:
        split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
    exllama_model.load(split)

    # Read the flag from our own config rather than the foreign attribute
    # stashed on ExLlamaV2Config above (the set is kept in case exllamav2
    # itself reads it).
    cache_class = ExLlamaV2Cache_8bit if exllama_config.cache_8bit else ExLlamaV2Cache
    exllama_cache = cache_class(exllama_model)

    model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)
    return model, tokenizer