| """ | |
| This script creates a CLI demo with vllm backand for the glm-4-9b model, | |
| allowing users to interact with the model through a command-line interface. | |
| Usage: | |
| - Run the script to start the CLI demo. | |
| - Interact with the model by typing questions and receiving responses. | |
Note: Responses are streamed to the terminal as plain text.
"""
import time
import asyncio
from typing import List, Dict

from transformers import AutoTokenizer
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine

MODEL_PATH = 'THUDM/glm-4-9b'

def load_model_and_tokenizer(model_dir: str):
    engine_args = AsyncEngineArgs(
        model=model_dir,
        tokenizer=model_dir,
        tensor_parallel_size=1,
        dtype="bfloat16",
        trust_remote_code=True,
        gpu_memory_utilization=0.3,  # low fraction for the demo; raise it if you have headroom
        enforce_eager=True,
        worker_use_ray=True,
        engine_use_ray=False,
        disable_log_requests=True
        # If you run into OOM, consider enabling the parameters below:
        # enable_chunked_prefill=True,
        # max_num_batched_tokens=8192
    )
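    # encode_special_tokens=True is a GLM-4-specific tokenizer option (not a standard
    # transformers argument); it makes markers such as <|user|> encode to their special
    # token ids instead of being tokenized as plain text.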
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir,
        trust_remote_code=True,
        encode_special_tokens=True
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    return engine, tokenizer

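# Build the engine and tokenizer once at module load; they are reused for every chat turn.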
engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)

async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        # GLM-4 stop tokens: <|endoftext|>, <|user|>, <|observation|>
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_dec_len,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)

    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
        yield output.outputs[0].text

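# Note: engine.generate yields the full text decoded so far on each step, so vllm_gen
# streams cumulative output; chat() prints only the newly added suffix.
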
async def chat():
    history = []
    max_length = 8192
    top_p = 0.8
    temperature = 0.6

    print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])

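        # Rebuild the chat-template message list from the stored (user, assistant) pairs;
        # the newest entry has no assistant reply yet, so it is sent as the final user turn.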
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})

        print("\nGLM-4: ", end="")
        current_length = 0
        output = ""
        async for output in vllm_gen(messages, top_p, temperature, max_length):
            print(output[current_length:], end="", flush=True)
            current_length = len(output)

        history[-1][1] = output


if __name__ == "__main__":
    asyncio.run(chat())