Spaces:
Runtime error
Runtime error
| # -*- encoding: utf-8 -*- | |
| import os | |
| import sys | |
| import torch | |
| import argparse | |
| from transformers import AutoTokenizer | |
| from sat.model.mixins import CachedAutoregressiveMixin | |
| from sat.quantization.kernels import quantize | |
| from model import VisualGLMModel, chat | |
| from finetune_visualglm import FineTuneVisualGLMModel | |
| from sat.model import AutoModel | |
def _parse_args():
    """Parse the command-line options of the VisualGLM-6B chat demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--english", action='store_true', help='only output English')
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
    parser.add_argument("--from_pretrained", type=str, default="visualglm-6b", help='pretrained ckpt')
    parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
    parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
    return parser.parse_args()


def _load_model(args):
    """Load the checkpoint named by ``args.from_pretrained`` for inference.

    The model is put in eval mode, optionally quantized to ``args.quant``
    bits, and given the cached auto-regressive mixin used for generation.
    """
    # Without quantization, build the weights directly on the GPU when one is
    # available. With quantization, load on the CPU first and only move to
    # CUDA after quantize() has run (below).
    on_gpu = torch.cuda.is_available() and args.quant is None
    model, _model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,
            skip_init=True,
            use_gpu_initialization=on_gpu,
            device='cuda' if on_gpu else 'cpu',
        ))
    model = model.eval()
    if args.quant:
        quantize(model.transformer, args.quant)
        if torch.cuda.is_available():
            model = model.cuda()
    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    return model


def _ask(english, prompt_zh, prompt_en):
    """Read one line from stdin, showing the prompt in the selected language."""
    return input(prompt_en if english else prompt_zh)


def main():
    """Run the interactive VisualGLM-6B command-line chat demo.

    Each outer iteration starts a fresh conversation. The user either enters
    an image path/URL, in which case the first query is ``--prompt_en`` or
    ``--prompt_zh``, or presses Enter for a text-only chat.

    Inside a conversation:
    - ``clear`` starts over.
    - ``stop`` exits the process via ``sys.exit(0)``.

    At the image prompt:
    - ``stop`` returns from main().
    - ``clear`` simply asks again.

    End of input (Ctrl-D or exhausted piped stdin) ends the demo without a
    traceback.
    """
    args = _parse_args()
    model = _load_model(args)
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    english = args.english
    if not english:
        print('欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
    else:
        print('Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')

    # The text after the last occurrence of this marker is printed as the answer.
    sep = 'A:' if english else '答:'
    # NOTE(review): in English mode token ids 63823..129999 are excluded from
    # sampling, presumably the Chinese part of the ChatGLM vocabulary. Confirm.
    invalid_slices = [slice(63823, 130000)] if english else []

    try:
        with torch.no_grad():
            while True:
                # Fresh conversation: no history and no cached image.
                history = None
                cache_image = None
                image_path = _ask(
                    english,
                    "请输入图像路径或URL(回车进入纯文本对话): ",
                    "Please enter the image path or URL (press Enter for plain text conversation): ",
                )
                if image_path == 'stop':
                    break
                if image_path == 'clear':
                    # Nothing to reset yet. Don't hand "clear" to chat() as an image path.
                    continue
                if image_path:
                    # The first round with an image uses the configured description prompt.
                    query = args.prompt_en if english else args.prompt_zh
                else:
                    query = _ask(english, "用户:", "User: ")
                while True:
                    if query == "clear":
                        break
                    if query == "stop":
                        sys.exit(0)
                    try:
                        response, history, cache_image = chat(
                            image_path,
                            model,
                            tokenizer,
                            query,
                            history=history,
                            image=cache_image,
                            max_length=args.max_length,
                            top_p=args.top_p,
                            temperature=args.temperature,
                            top_k=args.top_k,
                            english=english,
                            invalid_slices=invalid_slices,
                        )
                    except Exception as e:
                        # Top-level REPL boundary: report the error and start a
                        # new conversation instead of crashing the demo.
                        print(e)
                        break
                    print("VisualGLM-6B:" + response.split(sep)[-1].strip())
                    # Later rounds pass image_path=None together with the
                    # cache_image returned by the previous chat() call.
                    image_path = None
                    query = _ask(english, "用户:", "User: ")
    except EOFError:
        # stdin was closed (Ctrl-D or end of piped input): finish quietly.
        print()
if __name__ == "__main__":
    # Start the interactive demo only when executed as a script, not on import.
    # main() returns None, so sys.exit() ends the process with status 0.
    sys.exit(main())