Spaces:
Runtime error
Runtime error
| import torch | |
| import argparse | |
| import PIL | |
| from PIL import Image | |
| import os | |
| from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| from conversation import conv_templates, SeparatorStyle | |
| from torchvision import transforms | |
| from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN | |
| from threading import Thread | |
| from unitok.config import Args | |
| from unitok.model import UniTok | |
| from model.builder import load_pretrained_model | |
| from mm_utils import tokenizer_image_token, get_model_name_from_path | |
# Sentinel id for the image placeholder. This rebinds the name imported from
# `constants` above, so this literal is the value `main` actually passes to
# `tokenizer_image_token`.
IMAGE_TOKEN_INDEX = -200
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas, centring it along the shorter side.

    Args:
        pil_img: The PIL image to pad.
        background_color: Fill colour for the padded area, passed straight to
            ``Image.new`` (e.g. an RGB tuple for an RGB image).

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image of
        side ``max(width, height)`` in the same mode, with ``pil_img`` pasted
        in the middle.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img

    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Shift only along the axis that needs padding so the content is centred.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def _load_unitok(ckpt_path):
    """Build the UniTok image tokenizer from ``ckpt_path``, move it to CUDA and set eval mode.

    The checkpoint must hold the config under ``'args'`` and the weights under
    ``['trainer']['unitok']``.
    """
    # NOTE(review): torch.load is pickle-based; unless weights_only=True is in
    # effect for the installed torch version, an untrusted file can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    vae_cfg = Args()
    vae_cfg.load_state_dict(ckpt['args'])
    vq_model = UniTok(vae_cfg)
    vq_model.load_state_dict(ckpt['trainer']['unitok'])
    vq_model.to('cuda')
    vq_model.eval()
    return vq_model


def _build_prompt(user_prompt):
    """Return the 'llava_v1' chat prompt for one user turn that starts with an image slot."""
    # '<image>' marks where the image goes; presumably tokenizer_image_token
    # swaps it for IMAGE_TOKEN_INDEX -- confirm in mm_utils.
    qs = '<boi><image><eoi>' + '\n' + user_prompt
    conv = conv_templates['llava_v1'].copy()
    conv.append_message(conv.roles[0], qs)
    # A None assistant message presumably leaves the prompt open for the reply.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def _load_image_tensor(image_path, crop_size=256):
    """Load ``image_path`` as a normalized (1, 3, crop_size, crop_size) float tensor on CPU."""
    transform = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor(),
        # Maps ToTensor's [0, 1] pixel range to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    # Use a context manager so the underlying file handle is closed;
    # convert() returns a fully loaded copy.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    # Pad to a square first so the fixed-size resize does not distort the aspect ratio.
    pad_image = expand2square(image, (122, 116, 104))
    return transform(pad_image).unsqueeze(0)


def main(args):
    """Answer ``args.prompt`` about the image at ``args.image_path`` and print the reply.

    Args:
        args: Parsed CLI namespace providing ``unitok_path``, ``mllm_path``,
            ``prompt``, ``image_path`` and ``load_8bit``.

    Raises:
        Any exception raised by ``vqllm.generate_mllm`` in the worker thread.
        It is re-raised here after the thread has finished.
    """
    vq_model = _load_unitok(args.unitok_path)

    model_path = os.path.expanduser(args.mllm_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, vqllm, _image_processor, _context_len = load_pretrained_model(
        model_path, model_name, load_8bit=args.load_8bit)

    prompt = _build_prompt(args.prompt)
    print(prompt)

    img = _load_image_tensor(args.image_path).to('cuda')
    with torch.no_grad():
        vq_code = vq_model.img_to_idx(img)
    image_codes = vq_code.unsqueeze(0)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = {
        "inputs": input_ids.unsqueeze(0).to("cuda:0"),
        "images": image_codes.to("cuda:0"),
        "max_new_tokens": 1024,
        "bos_token_id": tokenizer.bos_token_id,  # Begin of sequence token
        "eos_token_id": tokenizer.eos_token_id,  # End of sequence token
        "pad_token_id": tokenizer.pad_token_id,  # Pad token
        "streamer": streamer,
    }

    errors = []

    def _generate():
        # Runs in a worker thread so the streamer can be consumed below as text arrives.
        try:
            vqllm.generate_mllm(**generation_kwargs)
        except Exception as exc:
            # TextIteratorStreamer (timeout=None) blocks until it gets a stop
            # signal. A crashed generate call never sends one, so end the
            # stream here and hand the error back to the main thread.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    generated_text = "".join(streamer)
    thread.join()
    if errors:
        raise errors[0]
    print(generated_text)
if __name__ == '__main__':
    # The built-in defaults are Windows paths (D:\..., C:\...) that will not
    # exist on most machines. Pass the flags or set UNITOK_PATH /
    # UNITOK_MLLM_PATH to override them without editing this file.
    parser = argparse.ArgumentParser(
        description='Generate a text reply to a prompt about an image with the UniTok MLLM.')
    parser.add_argument(
        '--unitok_path', type=str,
        default=os.environ.get(
            'UNITOK_PATH',
            r'D:\projects\liquid_app\UniTok\UniTok_weights\unitok_tokenizer\unitok_tokenizer.pth'),
        required=False, help='path to the UniTok tokenizer checkpoint (.pth)')
    parser.add_argument(
        '--mllm_path', type=str,
        default=os.environ.get('UNITOK_MLLM_PATH', r'C:\debug_ckpts\unitok_mllm'),
        required=False, help='path to the UniTok MLLM checkpoint directory')
    parser.add_argument('--prompt', type=str, required=True, help='input text prompt')
    parser.add_argument('--image_path', type=str, required=True, help='input image path')
    parser.add_argument('--load_8bit', action='store_true', default=False, help='use 8bit to save memory')
    args = parser.parse_args()
    main(args)