import contextlib

import clip
import torch
import torch.nn as nn
from einops import rearrange
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer

from leo.img_encoder import GridFeatureExtractor2D
from leo.pcd_encoder import OSE3D
from leo.grounding_head import SequentialGroundHead
from leo.utils import get_mlp_head


def maybe_autocast(model, dtype='bf16', enabled=True):
    # If on CPU, do not use autocast.
    # If on GPU, use autocast with the requested dtype ('bf16' or 'fp16'), falling back to fp32.
    enable_autocast = model.device != torch.device('cpu')
    if dtype == 'bf16':
        dtype = torch.bfloat16
    elif dtype == 'fp16':
        dtype = torch.float16
    else:
        dtype = torch.float32
    if enable_autocast:
        return torch.cuda.amp.autocast(dtype=dtype, enabled=enabled)
    else:
        return contextlib.nullcontext()
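
# Example usage of `maybe_autocast` (a minimal sketch): the context manager is a no-op on CPU and
# enables mixed-precision autocast on GPU, so call sites can stay branch-free, e.g.
#   with maybe_autocast(model, dtype='bf16'):
#       outputs = model(inputs)
# `model` and `inputs` are placeholders here; within this file it is always invoked as
# `maybe_autocast(self)` with the default bf16 dtype.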


def disabled_train(self, mode=True):
    """
    Overwrite model.train with this function to make sure train/eval mode does not change anymore.
    """
    return self


class SequentialGrounder(torch.nn.Module):
    def __init__(self, predict_mode=False):
        super().__init__()
        cfg = {
            "launch_mode": "hf",
            "model": {
                "llm": {
                    "name": "Vicuna7B",
                    "cfg_path": "/scratch/generalvision/vicuna-7b",
                    "hf_cfg_path": "huangjy-pku/vicuna-7b",
                    "truncation_side": "right",
                    "max_context_len": 256,
                    "max_out_len": 256,
                    "lora": {
                        "flag": True,
                        "rank": 16,
                        "alpha": 16,
                        "dropout": 0.0,
                        "target_modules": ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
                    },
                },
                "clip_txt_guidance": {
                    "flag": False,
                    "clip_out_dim": 1024,
                },
            },
        }
        self.predict_mode = predict_mode

        # LLM
        llm_name = cfg['model']['llm']['name']
        if cfg['launch_mode'] == 'hf':
            llm_cfg_path = cfg['model']['llm']['hf_cfg_path']
        else:
            llm_cfg_path = cfg['model']['llm']['cfg_path']
        llm_truncation_side = cfg['model']['llm']['truncation_side']
        if 'vicuna' in llm_name.lower():
            self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
            self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
            self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
        else:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
            self.llm_model = AutoModelForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
        for param in self.llm_model.parameters():
            param.requires_grad = False
        self.llm_model.eval()
        self.llm_model.train = disabled_train

        # 2D vision
        self.img_encoder = GridFeatureExtractor2D()
        self.img_proj = nn.Linear(
            self.img_encoder.out_channels, self.llm_model.config.hidden_size
        )

        # 3D vision
        self.pcd_encoder = OSE3D()
        self.pcd_proj = nn.Linear(256, self.llm_model.config.hidden_size)

        # type embedding
        # self.img_type_embed = nn.Parameter(torch.zeros(self.llm_model.config.hidden_size), requires_grad=True)
        # self.pcd_type_embed = nn.Parameter(torch.zeros(self.llm_model.config.hidden_size), requires_grad=True)

        # LoRA
        if cfg['model']['llm']['lora']['flag']:
            lora_config = LoraConfig(
                r=cfg['model']['llm']['lora']['rank'],
                lora_alpha=cfg['model']['llm']['lora']['alpha'],
                target_modules=cfg['model']['llm']['lora']['target_modules'],
                lora_dropout=cfg['model']['llm']['lora']['dropout'],
                bias='none',
                modules_to_save=[],
            )
            self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)

        self.max_context_len = cfg['model']['llm']['max_context_len']
        self.max_out_len = cfg['model']['llm']['max_out_len']

        # additional text x multi-modal tokens fusion
        self.clip_txt_guidance = cfg['model']['clip_txt_guidance']['flag']
        if self.clip_txt_guidance:
            self.clip_model = clip.load('RN50')[0]
            for param in self.clip_model.parameters():
                param.requires_grad = False
            self.clip_model.eval()
            self.clip_model.train = disabled_train
            self.clip_proj = nn.Linear(cfg['model']['clip_txt_guidance']['clip_out_dim'], self.llm_model.config.hidden_size)

        # grounding head
        self.ground_head = SequentialGroundHead()
        self.obj_cls_head = get_mlp_head(4096, 768, 607, 0.3)
        self.pre_grounding = True
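
    # Trainability summary (descriptive comment): the base LLM (and the CLIP text encoder, when
    # `clip_txt_guidance` is enabled) is frozen and kept in eval mode with `train()` disabled;
    # the LoRA adapters injected by peft, the img/pcd projection layers, the 2D/3D encoders
    # (unless they freeze their own backbones internally), the grounding head, and the object
    # classification head remain trainable.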

    @property
    def device(self):
        return list(self.parameters())[0].device

    def build_right_justified_sequence(self, data_dict):
        """
        Concat six sequences: `prompt_before_obj`, `prompt_middle_1`, `img_tokens`, `prompt_middle_2`, `obj_tokens`, `prompt_after_obj`.
        Return a right-justified sequence for the causal LM: <pad>, <role/situation>, <img>, <objs>, <instruction>.
        """
        device = self.device
        bs = len(data_dict['prompt_before_obj'])

        self.llm_tokenizer.padding_side = 'left'
        text_input_tokens_pre = self.llm_tokenizer(
            data_dict['prompt_before_obj'],
            return_tensors='pt',
            padding='longest'
        ).to(device)   # [PAD, BOS, tokens], (B, T1)
        text_input_tokens_mid1 = self.llm_tokenizer(
            data_dict['prompt_middle_1'],
            return_tensors='pt',
            padding='longest'
        ).to(device)

        img_tokens = data_dict['img_tokens'].to(device)
        img_masks = data_dict['img_masks'].to(device)
        img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1))

        text_input_tokens_mid2 = self.llm_tokenizer(
            data_dict['prompt_middle_2'],
            return_tensors='pt',
            padding='longest'
        ).to(device)

        obj_tokens = data_dict['obj_tokens'].to(device)
        obj_masks = data_dict['obj_masks'].to(device)

        # additional clip fusion
        if self.clip_txt_guidance:
            with torch.no_grad():
                clip_fts = self.clip_model.encode_text(
                    clip.tokenize(data_dict['prompt_after_obj'], truncate=True).to(device)
                )
            clip_fts = self.clip_proj(clip_fts)
            # (B, N, C)
            img_tokens = torch.einsum('bnc,bc->bnc', img_tokens, clip_fts)
            obj_tokens = torch.einsum('bnc,bc->bnc', obj_tokens, clip_fts)

        self.llm_tokenizer.padding_side = 'right'   # no need to be 'left', as padding tokens will be shifted
        self.llm_tokenizer.truncation_side = 'left'   # truncate history
        text_input_tokens_post = self.llm_tokenizer(
            data_dict['prompt_after_obj'],
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_context_len,
        ).to(device)   # [BOS, tokens, PAD], (B, T3)

        assert text_input_tokens_mid1.attention_mask.all() and text_input_tokens_mid2.attention_mask.all(), \
            "prompt_middle should be the same and thus no padding"

        # remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq"
        text_input_tokens_mid1.input_ids = text_input_tokens_mid1.input_ids[:, 1:]
        text_input_tokens_mid1.attention_mask = text_input_tokens_mid1.attention_mask[:, 1:]
        text_input_tokens_mid2.input_ids = text_input_tokens_mid2.input_ids[:, 1:]
        text_input_tokens_mid2.attention_mask = text_input_tokens_mid2.attention_mask[:, 1:]
        text_input_tokens_post.input_ids = text_input_tokens_post.input_ids[:, 1:]
        text_input_tokens_post.attention_mask = text_input_tokens_post.attention_mask[:, 1:]

        for i in range(bs):
            if not img_masks[i].any():
                # no image input, also mask the text prompt for image tokens
                text_input_tokens_mid1.attention_mask[i].fill_(0)

        inputs_embeds_pre = self.llm_model.get_input_embeddings()(text_input_tokens_pre.input_ids)
        inputs_embeds_mid1 = self.llm_model.get_input_embeddings()(text_input_tokens_mid1.input_ids)
        inputs_embeds_mid2 = self.llm_model.get_input_embeddings()(text_input_tokens_mid2.input_ids)
        inputs_embeds_post = self.llm_model.get_input_embeddings()(text_input_tokens_post.input_ids)

        # since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first
        inputs_embeds_mid = torch.cat([inputs_embeds_mid1, img_tokens, inputs_embeds_mid2, obj_tokens], dim=1)
        attn_mask_mid = torch.cat(
            [text_input_tokens_mid1.attention_mask, img_masks, text_input_tokens_mid2.attention_mask, obj_masks],
            dim=1,
        )

        post_pad_length = torch.logical_not(text_input_tokens_post.attention_mask).sum(-1)

        bs, l1, hidden_dim = inputs_embeds_pre.shape
        _, l2, _ = inputs_embeds_mid.shape
        _, l3, _ = inputs_embeds_post.shape

        inputs_embeds = torch.zeros(bs, l1 + l2 + l3, hidden_dim).type(inputs_embeds_pre.dtype).to(device)
        attention_mask = torch.zeros(bs, l1 + l2 + l3).type(obj_masks.dtype).to(device)

        # assign by chunks
        for i in range(bs):
            post_pad_len = post_pad_length[i]
            if post_pad_len > 0:
                inputs_embeds[i, :post_pad_len] = inputs_embeds_post[i, -post_pad_len:]
                attention_mask[i, :post_pad_len] = 0
                inputs_embeds[i, post_pad_len+l1+l2:] = inputs_embeds_post[i, :-post_pad_len]
                attention_mask[i, post_pad_len+l1+l2:] = 1
            else:
                # no padding
                inputs_embeds[i, -l3:] = inputs_embeds_post[i]
                attention_mask[i, -l3:] = 1
            inputs_embeds[i, post_pad_len: post_pad_len+l1] = inputs_embeds_pre[i]
            attention_mask[i, post_pad_len: post_pad_len+l1] = text_input_tokens_pre.attention_mask[i]
            inputs_embeds[i, post_pad_len+l1: post_pad_len+l1+l2] = inputs_embeds_mid[i]
            attention_mask[i, post_pad_len+l1: post_pad_len+l1+l2] = attn_mask_mid[i]

        return inputs_embeds, attention_mask, (l1, l2, l3)
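
    # Illustrative layout of the assembled sequence for one sample, where p is that sample's
    # post-padding length (derived directly from the assignments above, not an extra assumption):
    #   [0, p)              PAD embeddings taken from the tail of `prompt_after_obj`
    #   [p, p+l1)           `prompt_before_obj` (left-padded, so its own PADs stay masked out)
    #   [p+l1, p+l1+l2)     prompt_middle_1 + img_tokens + prompt_middle_2 + obj_tokens
    #   [p+l1+l2, l1+l2+l3) the non-pad part of `prompt_after_obj`, i.e. the instruction
    # so real tokens are right-justified and generation can start immediately after them.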

    def forward(self, data_dict):
        """
        data_dict requires keys:
          # input
          prompt_before_obj: list of str, (B,)
          prompt_middle_1: list of str, (B,)
          prompt_middle_2: list of str, (B,)
          prompt_after_obj: list of str, (B,)
          obj_fts: (B, N, P, 6), xyz + rgb
          obj_masks: (B, N), 1 valid and 0 masked
          obj_locs: (B, N, 6), xyz + whd
          anchor_locs: (B, 3)
          anchor_orientation: (B, C)
          img_fts: (B, 3, H, W), rgb
          img_masks: (B, 1), 1 valid and 0 masked
          # output
          output_gt: list of str, (B,)
        """
        if self.predict_mode:
            return self.generate(data_dict=data_dict)
        device = self.device
        bs = len(data_dict['prompt_after_obj'])
        data_dict['bs'] = bs

        if 'obj_tokens' not in data_dict:
            # obtain obj tokens
            data_dict = self.pcd_encoder(data_dict)
        # TO CHANGE FOR DEBUG
        # self.llm_model.float()
        # data_dict['obj_tokens'] = torch.zeros((data_dict['obj_locs'].shape[0], data_dict['obj_locs'].shape[1], 256)).to(device=device)
        data_dict['obj_tokens'] = self.pcd_proj(data_dict['obj_tokens'].to(device))
        # data_dict['obj_tokens'] = data_dict['obj_tokens'] + self.pcd_type_embed
        data_dict['img_tokens'] = self.img_proj(self.img_encoder(data_dict['img_fts']))
        # data_dict['img_tokens'] = data_dict['img_tokens'] + self.img_type_embed

        # build input embeds and record prompt positions
        inputs_embeds, attention_mask, input_length = self.build_right_justified_sequence(data_dict=data_dict)
        obj_token_length = data_dict['obj_masks'].shape[1]
        # (B, T1+O+T2, D), (B, T1+O+T2)

        self.llm_tokenizer.padding_side = 'right'
        self.llm_tokenizer.truncation_side = 'right'
        text_output_tokens = self.llm_tokenizer(
            [t + self.llm_tokenizer.eos_token for t in data_dict['output_gt']],
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_out_len,
        ).to(device)

        # record positions of the special grounding token <s> (the leading BOS is excluded)
        grd_token_id = self.llm_tokenizer.convert_tokens_to_ids(['<s>'])[0]
        out_input_ids_remove_first_sos = text_output_tokens.input_ids.clone()
        out_input_ids_remove_first_sos[:, 0] = -100
        grd_ind_0, grd_ind_1 = (out_input_ids_remove_first_sos == grd_token_id).nonzero(as_tuple=True)

        text_output_embeds = self.llm_model.get_input_embeddings()(text_output_tokens.input_ids)   # (B, T3, D)
        inputs_embeds = torch.cat([inputs_embeds, text_output_embeds], dim=1)   # (B, T1+O+T2+T3, D)
        attention_mask = torch.cat([attention_mask, text_output_tokens.attention_mask], dim=1)   # (B, T1+O+T2+T3)

        # construct targets
        targets = torch.zeros_like(attention_mask).long().fill_(-100)   # (B, T1+O+T2+T3)

        # only apply loss to answer tokens
        targets_idx = text_output_tokens.attention_mask.bool()
        targets[:, -targets_idx.shape[1]:][targets_idx] = text_output_tokens.input_ids[targets_idx]

        # do not predict bos token, regard it as condition instead
        targets[:, -targets_idx.shape[1]] = -100

        with maybe_autocast(self):
            outputs = self.llm_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=True,
            )
        logits = outputs.logits.float()
        last_hidden_state = outputs.hidden_states[-1]

        # different from the loss inside `llm_model.forward`, here we take mean of each sequence instead of sum
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = targets[..., 1:].contiguous()
        num_tokens_for_loss = (shift_labels >= 0).int().sum(1)   # (B,)

        shift_logits = rearrange(shift_logits, 'b t v -> (b t) v')
        shift_labels = rearrange(shift_labels, 'b t -> (b t)')
        shift_labels = shift_labels.to(shift_logits.device)

        # record for llm loss
        data_dict['llm_logits'] = shift_logits
        data_dict['llm_labels'] = shift_labels
        data_dict['num_tokens_for_loss'] = num_tokens_for_loss

        # record for grounding loss
        grd_list = []
        obj_list = []
        mask_list = []
        for step in range(len(grd_ind_0)):
            batch_ind = grd_ind_0[step]
            grd_token_ind = grd_ind_1[step]
            if self.pre_grounding:
                output_obj_tokens = data_dict['obj_tokens'][batch_ind]
            else:
                output_obj_tokens = last_hidden_state[batch_ind, input_length[0] + input_length[1] - obj_token_length: input_length[0] + input_length[1], :]
            output_grd_tokens = last_hidden_state[batch_ind, sum(input_length) + grd_token_ind: sum(input_length) + grd_token_ind + 1, :]
            grd_list.append(output_grd_tokens)
            obj_list.append(output_obj_tokens)
            mask_list.append(data_dict['obj_masks'][batch_ind])
        output_obj = torch.stack(obj_list).float()
        output_grd = torch.stack(grd_list).float()
        data_dict['ground_logits'] = self.ground_head(output_obj, output_grd, torch.stack(mask_list))
        # data_dict['ground_label'] = torch.concat(data_dict['tgt_object_id'], dim=0)

        # record for cls loss
        # obj_cls_post_embeds = last_hidden_state[:, input_length[0] + input_length[1] - obj_token_length: input_length[0] + input_length[1], :].float()
        obj_cls_post_embeds = data_dict['obj_tokens'].float()
        data_dict['obj_cls_post_logits'] = self.obj_cls_head(obj_cls_post_embeds)
        return data_dict
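
    # A minimal sketch (not part of this module) of how a training loop might turn the recorded
    # fields into the per-sequence-mean LLM loss described above; `F` (torch.nn.functional) and
    # `bs` are assumed to exist at the call site:
    #   token_loss = F.cross_entropy(data_dict['llm_logits'], data_dict['llm_labels'],
    #                                ignore_index=-100, reduction='none')        # (B*T,)
    #   llm_loss = (token_loss.view(bs, -1).sum(1) / data_dict['num_tokens_for_loss']).mean()
    # The grounding and classification losses would similarly consume `ground_logits` and
    # `obj_cls_post_logits`; the actual loss implementation lives in the training code.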

    def generate(
        self,
        data_dict,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=256,
        min_length=1,
        top_p=0.9,
        repetition_penalty=6.0,
        length_penalty=1,
        num_captions=1,
        temperature=1,
    ):
        """
        data_dict requires the same keys as forward() except output_gt.
        """
        device = self.device
        bs = len(data_dict['prompt_after_obj'])
        data_dict['bs'] = bs

        if 'obj_tokens' not in data_dict:
            # obtain obj tokens
            data_dict = self.pcd_encoder(data_dict)
        # TO CHANGE FOR DEBUG
        # self.llm_model.float()
        # data_dict['obj_tokens'] = torch.zeros((data_dict['obj_locs'].shape[0], data_dict['obj_locs'].shape[1], 256)).to(device=device)
        data_dict['obj_tokens'] = self.pcd_proj(data_dict['obj_tokens'].to(device))
        # data_dict['obj_tokens'] = data_dict['obj_tokens'] + self.pcd_type_embed
        data_dict['img_tokens'] = self.img_proj(self.img_encoder(data_dict['img_fts']))
        # data_dict['img_tokens'] = data_dict['img_tokens'] + self.img_type_embed

        inputs_embeds, attention_mask, input_length = self.build_right_justified_sequence(data_dict=data_dict)
        obj_token_length = data_dict['obj_masks'].shape[1]

        # give bos token as condition
        bos_tokens = self.llm_tokenizer(
            [self.llm_tokenizer.bos_token] * bs,
            return_tensors='pt',
        ).to(device)
        bos_tokens_ids = bos_tokens.input_ids[:, 0:1]   # (B, 1)
        bos_tokens_attn = bos_tokens.attention_mask[:, 0:1]   # (B, 1)

        # append a `bos_token` as the generation condition
        bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids)   # (B, 1, D)
        inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1)   # (B, T1+O+T2+1, D)
        attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1)   # (B, T1+O+T2+1)

        with maybe_autocast(self):
            outputs = self.llm_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
                return_dict_in_generate=True,
                output_hidden_states=True,
                output_scores=True,
            )

        # note: output_ids_idx - 1 = step idx, because we do not predict [BOS]
        beam_indices = outputs.beam_indices   # bs x step, beam indices range over (bs x beam)
        scores = outputs.scores   # step x (bs x beam) x vocab
        hidden_states = outputs.hidden_states   # step x layer x (bs x beam) x token_num x hidden_dim
        outputs = outputs.sequences   # bs x output_ids
        outputs[outputs == self.llm_tokenizer.unk_token_id] = self.llm_tokenizer.eos_token_id
        # data_dict['output_tokens'] = outputs   # unable to gather variable-length tensors

        # record for grounding
        grd_token_id = self.llm_tokenizer.convert_tokens_to_ids(['<s>'])[0]
        out_input_ids_remove_first_sos = outputs.clone()
        out_input_ids_remove_first_sos[:, 0] = -100
        grd_ind_0, grd_ind_1 = (out_input_ids_remove_first_sos == grd_token_id).nonzero(as_tuple=True)

        grd_list = []
        grd_batch_ind_list = []
        obj_list = []
        mask_list = []
        if len(grd_ind_0) > 0:
            for step in range(len(grd_ind_0)):
                batch_ind = grd_ind_0[step]
                grd_token_ind = grd_ind_1[step]
                # output_obj_tokens = last_hidden_state[batch_ind, input_length[0] + input_length[1] - obj_token_length: input_length[0] + input_length[1], :]
                output_obj_tokens = data_dict['obj_tokens'][batch_ind]
                # grd_token_ind - 1 because the first token is bos
                output_grd_tokens = hidden_states[grd_token_ind - 1][-1][beam_indices[batch_ind, grd_token_ind - 1]][-1].unsqueeze(0)
                grd_list.append(output_grd_tokens)
                grd_batch_ind_list.append(batch_ind)
                obj_list.append(output_obj_tokens)
                mask_list.append(data_dict['obj_masks'][batch_ind])
            output_obj = torch.stack(obj_list).float()
            output_grd = torch.stack(grd_list).float()
            data_dict['ground_logits'] = self.ground_head(output_obj, output_grd, torch.stack(mask_list))
        else:
            data_dict['ground_logits'] = None
        # data_dict['ground_label'] = torch.concat(data_dict['tgt_object_id'], dim=0)
        data_dict['grd_batch_ind_list'] = grd_batch_ind_list

        output_txt = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_txt = [txt.strip() for txt in output_txt]
        data_dict['output_txt'] = output_txt
        return data_dict
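

# A hypothetical smoke test, not part of the original module: it follows the tensor shapes
# documented in `forward` and uses made-up prompts and sizes. Instantiating the model downloads
# the Vicuna-7B checkpoint referenced in `cfg`, so treat this as a sketch rather than a CI test.
if __name__ == '__main__':
    model = SequentialGrounder(predict_mode=True)
    B, N, P = 1, 4, 1024   # batch size, number of objects, points per object (illustrative)
    data_dict = {
        'prompt_before_obj': ['You are an assistant situated in a 3D scene.'],
        'prompt_middle_1': ['Ego-view image:'],
        'prompt_middle_2': ['Objects:'],
        'prompt_after_obj': ['Find the chair next to the table.'],
        'obj_fts': torch.rand(B, N, P, 6),                    # xyz + rgb
        'obj_masks': torch.ones(B, N, dtype=torch.bool),
        'obj_locs': torch.rand(B, N, 6),                      # xyz + whd
        'anchor_locs': torch.zeros(B, 3),
        'anchor_orientation': torch.tensor([[0.0, 0.0, 0.0, 1.0]]),  # (B, C); C=4 is a guess
        'img_fts': torch.rand(B, 3, 224, 224),
        'img_masks': torch.ones(B, 1, dtype=torch.bool),
    }
    with torch.no_grad():
        out = model(data_dict)
    print(out['output_txt'])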