| """ | |
| Here is an example of using batch request glm-4-9b, | |
| here you need to build the conversation format yourself and then call the batch function to make batch requests. | |
| Please note that in this demo, the memory consumption is significantly higher. | |
| """ | |
from typing import Optional, Union

from transformers import AutoModel, AutoTokenizer, LogitsProcessorList

MODEL_PATH = 'THUDM/glm-4-9b-chat'

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    encode_special_tokens=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
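# Memory note (an assumption, not part of the original demo): loading the weights in
# half precision is a common way to lower the footprint mentioned in the docstring, e.g.:
#   import torch
#   model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True,
#                                     device_map="auto", torch_dtype=torch.bfloat16).eval()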
def process_model_outputs(inputs, outputs, tokenizer):
    # Strip the (padded) prompt tokens from each sequence and decode only the newly
    # generated part of the output.
    responses = []
    for input_ids, output_ids in zip(inputs.input_ids, outputs):
        response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
        responses.append(response)
    return responses
def batch(
        model,
        tokenizer,
        messages: Union[str, list[str]],
        max_input_tokens: int = 8192,
        max_new_tokens: int = 8192,
        num_beams: int = 1,
        do_sample: bool = True,
        top_p: float = 0.8,
        temperature: float = 0.8,
        logits_processor: Optional[LogitsProcessorList] = None,
):
    # Avoid a mutable default argument for the logits processor.
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    messages = [messages] if isinstance(messages, str) else messages
    batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
                               max_length=max_input_tokens).to(model.device)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        "eos_token_id": model.config.eos_token_id
    }
    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
    batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
    return batched_response
if __name__ == "__main__":
    batch_message = [
        [
            {"role": "user", "content": "Why couldn't my mom and dad take me along when they got married?"},
            {"role": "assistant", "content": "Because you had not been born yet when they got married."},
            {"role": "user", "content": "What was the question I just asked?"}
        ],
        [
            {"role": "user", "content": "Hello, who are you?"}
        ]
    ]

    batch_inputs = []
    max_input_tokens = 1024
    for i, messages in enumerate(batch_message):
        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        # len() counts characters here, which serves as a rough upper bound on the
        # number of tokens to pad each prompt to.
        max_input_tokens = max(max_input_tokens, len(new_batch_input))
        batch_inputs.append(new_batch_input)
    gen_kwargs = {
        "max_input_tokens": max_input_tokens,
        "max_new_tokens": 8192,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.8,
        "num_beams": 1,
    }

    batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
    for response in batch_responses:
        print("=" * 10)
        print(response)
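    # A minimal memory-saving variant (an assumption, not part of the original demo):
    # padding to the longest prompt in the batch rather than a fixed max_length avoids
    # padding every request out to max_input_tokens. The slicing in process_model_outputs
    # still works, because all prompts in the batch share the same padded length.
    #
    # batched_inputs = tokenizer(batch_inputs, return_tensors="pt", padding=True,
    #                            truncation=True, max_length=max_input_tokens).to(model.device)
    # batched_outputs = model.generate(**batched_inputs, max_new_tokens=8192, do_sample=True,
    #                                  top_p=0.8, temperature=0.8,
    #                                  eos_token_id=model.config.eos_token_id)
    # print(process_model_outputs(batched_inputs, batched_outputs, tokenizer))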