| """ | |
| This script creates a CLI demo with transformers backend for the glm-4v-9b model, | |
| allowing users to interact with the model through a command-line interface. | |
| Usage: | |
| - Run the script to start the CLI demo. | |
| - Interact with the model by typing questions and receiving responses. | |
| Note: The script includes a modification to handle markdown to plain text conversion, | |
| ensuring that the CLI interface displays formatted text correctly. | |
| """ | |
import os
from threading import Thread

import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    AutoModel,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    encode_special_tokens=True
)
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
).eval()
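# device_map="auto" relies on the accelerate package to place the bf16
# weights across the available GPU(s) and CPU automatically.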

## For INT4 inference
# model = AutoModel.from_pretrained(
#     MODEL_PATH,
#     trust_remote_code=True,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     torch_dtype=torch.bfloat16,
#     low_cpu_mem_usage=True
# ).eval()
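# load_in_4bit quantizes the linear-layer weights through bitsandbytes,
# roughly quartering weight memory versus bf16 at some cost in output
# quality; it requires the bitsandbytes package and a CUDA-capable GPU.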

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
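# transformers calls each StoppingCriteria after every decoding step, so the
# check above halts generation as soon as the newest token matches one of the
# ids in model.config.eos_token_id (which this code assumes is a list).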

if __name__ == "__main__":
    history = []
    max_length = 1024
    top_p = 0.8
    temperature = 0.6
    stop = StopOnTokens()
    uploaded = False
    image = None
    print("Welcome to the GLM-4V-9B CLI chat. Type your messages below.")
    image_path = input("Image Path: ")
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception:
        print("Invalid image path. Continuing with text-only conversation.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                # Attach the image once, on the first turn that reaches the model.
                if image and not uploaded:
                    messages[-1].update({"image": image})
                    uploaded = True
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        model_inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True
        ).to(next(model.parameters()).device)
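        # apply_chat_template with return_dict=True returns a BatchEncoding
        # (input_ids, attention_mask, and, for this model's remote-code
        # tokenizer, the preprocessed image tensor); .to(...) moves every
        # tensor onto the device holding the model's parameters.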
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=60,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generate_kwargs = {
            **model_inputs,
            "streamer": streamer,
            "max_new_tokens": max_length,
            "do_sample": True,
            "top_p": top_p,
            "temperature": temperature,
            "stopping_criteria": StoppingCriteriaList([stop]),
            "repetition_penalty": 1.2,
            "eos_token_id": [151329, 151336, 151338],
        }
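        # The hard-coded ids mirror the model's configured eos_token_id list;
        # they correspond to GLM-4's end-of-text and role-switch special
        # tokens, so generation stops both at end-of-text and when the model
        # would begin a new conversational turn.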
        # Run generation on a worker thread so tokens can be printed as the
        # streamer yields them.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        print("GLM-4V:", end="", flush=True)
        for new_token in streamer:
            if new_token:
                print(new_token, end="", flush=True)
                history[-1][1] += new_token
        history[-1][1] = history[-1][1].strip()