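"""Inference script for the LEO SequentialGrounder demo.

Loads a preprocessed ScanNet scan from ./assets, builds a language prompt from a
task description, and runs the model either to ground each action step to an
object id (predict_mode=False) or to additionally generate the plan text
(predict_mode=True).
"""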
import json
import os

import numpy as np
import torch
from torch.utils.data import default_collate

from leo.model import SequentialGrounder
from leo.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence

ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint/leo')

int2cat = json.load(open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8"))
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))
role_prompt = "You are an AI visual assistant situated in a 3D scene. " \
              "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). " \
              "You should properly respond to the USER's instruction according to the given visual information. "
# role_prompt = " "
egoview_prompt = "Ego-view image:"
objects_prompt = "Objects (including you) in the scene:"
task_prompt = "USER: {instruction} ASSISTANT:"
def get_prompt(instruction):
    """Assemble the prompt segments expected by the model for one instruction."""
    return {
        'prompt_before_obj': role_prompt,
        'prompt_middle_1': egoview_prompt,
        'prompt_middle_2': objects_prompt,
        'prompt_after_obj': task_prompt.format(instruction=instruction),
    }
def get_lang(task_item):
    """Build the language inputs for one task item.

    Returns the prompt dict from get_prompt(); if the task provides
    'action_steps', also stores the ground-truth plan text (action steps
    joined with the '<s>' separator) under 'output_gt'.
    """
    task_description = task_item['task_description']
    sentence = task_description
    data_dict = get_prompt(task_description)
    # scan_id = task_item['scan_id']
    if 'action_steps' in task_item:
        action_steps = task_item['action_steps']
        # tgt_object_id = [int(action['target_id']) for action in action_steps]
        # tgt_object_name = [action['label'] for action in action_steps]
        for action in action_steps:
            sentence += ' ' + action['action']
        data_dict['output_gt'] = ' '.join([action['action'] + ' <s>' for action in action_steps])
    # return scan_id, tgt_object_id, tgt_object_name, sentence, data_dict
    return data_dict
def load_data(scan_id):
    """Load a preprocessed scan from ASSET_DIR and convert it to a LEO input dict."""
    one_scan = {}
    # load scan: points, colors and per-point instance labels
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1  # normalize colors to [-1, 1]
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label
    # split the scan into per-instance (GT) object point clouds
    obj_pcds = []
    inst_ids = []
    inst_labels = []
    bg_indices = np.full((points.shape[0], ), 1, dtype=np.bool_)
    for inst_id in inst_to_label.keys():
        if inst_to_label[inst_id] in cat2int.keys():
            mask = instance_labels == inst_id
            if np.sum(mask) == 0:
                continue
            obj_pcds.append(pcds[mask])
            inst_ids.append(inst_id)
            inst_labels.append(cat2int[inst_to_label[inst_id]])
            if inst_to_label[inst_id] not in ['wall', 'floor', 'ceiling']:
                bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]
    # calculate axis-aligned boxes for matching
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size
    # load precomputed per-object point features
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', 'obj_feats.pth')
    one_scan['obj_feats'] = torch.load(feat_pth, map_location='cpu')
    # convert to pq3d input
    obj_labels = one_scan['inst_labels']  # N
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids']
    # object filter: drop background categories
    excluded_labels = ['wall', 'floor', 'ceiling']

    def keep_obj(i, obj_label):
        category = int2cat[obj_label]
        # filter out background
        if category in excluded_labels:
            return False
        # filter out objects not mentioned in the sentence
        return True

    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels) if keep_obj(i, obj_label)]
    # TODO: crop objects to max_obj_len and reorganize ids
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]
    # subsample each object to 1024 points (with replacement if it has fewer)
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
                                                  replace=len(obj_pcd) < 1024)] for obj_pcd in obj_pcds])
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes,
        "obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool),  # used for padding in collate
        "obj_ids": torch.LongTensor([obj_ids[i] for i in selected_obj_idxs])
    }
    # attach the precomputed point features
    data_dict['obj_feats'] = one_scan['obj_feats'].squeeze(0)
    # keep only the keys the model consumes
    useful_keys = ['tgt_object_id', 'scan_id', 'obj_labels', 'data_idx',
                   'obj_fts', 'obj_locs', 'obj_pad_masks', 'obj_ids',
                   'source', 'prompt_before_obj', 'prompt_middle_1',
                   'prompt_middle_2', 'prompt_after_obj', 'output_gt', 'obj_feats']
    for k in list(data_dict.keys()):
        if k not in useful_keys:
            del data_dict[k]
    # add new keys required by LEO (no ego-view image, identity anchor pose)
    data_dict['img_fts'] = torch.zeros(3, 224, 224)
    data_dict['img_masks'] = torch.LongTensor([0]).bool()
    data_dict['anchor_locs'] = torch.zeros(3)
    data_dict['anchor_orientation'] = torch.zeros(4)
    data_dict['anchor_orientation'][-1] = 1  # quaternion in xyzw order
    # convert to LEO naming
    data_dict['obj_masks'] = data_dict['obj_pad_masks']
    del data_dict['obj_pad_masks']
    return data_dict
def form_batch(data_dict):
    """Wrap a single sample into a batch of size one, padding variable-length object tensors."""
    batch = [data_dict]
    new_batch = {}
    # pad variable-length per-object tensors
    padding_keys = ['obj_fts', 'obj_locs', 'obj_masks', 'obj_labels', 'obj_ids']
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padded_tensor = pad_sequence(tensors, pad=0)
        new_batch[k] = padded_tensor
    # # list
    # list_keys = ['tgt_object_id']
    # for k in list_keys:
    #     new_batch[k] = [sample.pop(k) for sample in batch]
    # default collate for the remaining keys
    new_batch.update(default_collate(batch))
    return new_batch
def inference(scan_id, task, predict_mode=False):
    """Run grounding on one scan/task pair.

    With predict_mode=False, returns the grounded object id for each action
    step; with predict_mode=True, returns a list of dicts holding the predicted
    object-id sequence and the generated plan text.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = 'cpu'  # works for predict_mode=False, and for a local Gradio demo preview
    data_dict = load_data(scan_id)
    data_dict.update(get_lang(task))
    data_dict = form_batch(data_dict)
    # move tensors to the target device
    for key, value in data_dict.items():
        if isinstance(value, torch.Tensor):
            data_dict[key] = value.to(device)
    model = SequentialGrounder(predict_mode)
    load_msg = model.load_state_dict(
        torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'), strict=False)
    model.to(device)
    data_dict = model(data_dict)
    if not predict_mode:
        # pick the highest-scoring object id for each action step
        result_id_list = [data_dict['obj_ids'][0][torch.argmax(data_dict['ground_logits'][i]).item()]
                          for i in range(len(data_dict['ground_logits']))]
        return result_id_list
    # predict_mode=True: decode the generated plan text and the grounded object ids
    # tgt_object_id = data_dict['tgt_object_id']
    if data_dict['ground_logits'] is None:
        og_pred = []
    else:
        og_pred = torch.argmax(data_dict['ground_logits'], dim=1)
    grd_batch_ind_list = data_dict['grd_batch_ind_list']
    response_pred = []
    for i in range(1):  # batch size is 1 in this demo
        # target_sequence = list(tgt_object_id[i].cpu().numpy())
        predict_sequence = []
        if og_pred is not None:
            for j in range(len(og_pred)):
                if grd_batch_ind_list[j] == i:
                    predict_sequence.append(og_pred[j].item())
        obj_ids = data_dict['obj_ids']
        response_pred.append({
            'predict_object_id': [obj_ids[i][o].item() for o in predict_sequence],
            'pred_plan_text': data_dict['output_txt'][i]
        })
    return response_pred
if __name__ == '__main__':
    inference("scene0050_00", {
        "task_description": "Find the chair and move it to the table.",
        "action_steps": [
            {
                "target_id": "1",
                "label": "chair",
                "action": "Find the chair."
            },
            {
                "target_id": "2",
                "label": "table",
                "action": "Move the chair to the table."
            }
        ],
        "scan_id": "scene0050_00"
    }, predict_mode=True)
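    # Grounding-only sketch (assumed usage, not run here): passing the same task dict
    # with predict_mode=False returns one grounded object id per action step instead
    # of the generated plan text.
    # result_ids = inference("scene0050_00", task_dict, predict_mode=False)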