Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates.
"""Build a COCO-style image-info JSON for the CC3M training split.

Reads the tab-separated ``--ann`` file (one ``caption<TAB>url`` row per line).
Unless ``--not_download_image`` is given, each image is downloaded with
``wget`` to ``<save_image_path>/<row_index + 1>.jpg``. Every image that can be
opened and fully decoded is recorded with its id, file name, height, width and
caption; rows whose image is missing or unreadable are skipped.

The output JSON (``--out_path``) contains:
    * ``categories``  - copied verbatim from the ``categories`` key of ``--cat_info``
    * ``images``      - the collected image entries
    * ``annotations`` - an empty list
"""
import os
import json
import argparse
import subprocess
from PIL import Image
import numpy as np


def _download_image(url, dst):
    """Download *url* to *dst* with ``wget``.

    A failed download is not an error here: the subsequent decode check in
    ``_read_image_size`` will simply skip the row. Raises ``FileNotFoundError``
    if ``wget`` itself is not installed.
    """
    # Pass an argument list (no shell): URLs come from the TSV and may contain
    # shell metacharacters such as '&', ';' or quotes.
    subprocess.run(['wget', url, '-O', dst], check=False)


def _read_image_size(path):
    """Return ``(height, width)`` of the image at *path*, or ``None`` if unreadable."""
    try:
        with Image.open(path) as img:
            # convert() forces a full decode, so truncated/corrupt files are
            # rejected here rather than only files with a broken header.
            rgb = img.convert("RGB")
    except Exception:  # best effort: missing or undecodable images are skipped
        return None
    width, height = rgb.size
    return height, width


def main():
    """Parse CLI arguments, collect image entries and write the output JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--ann', default='datasets/cc3m/Train_GCC-training.tsv')
    parser.add_argument('--save_image_path', default='datasets/cc3m/training/')
    parser.add_argument('--cat_info', default='datasets/lvis/lvis_v1_val.json')
    parser.add_argument('--out_path', default='datasets/cc3m/train_image_info.json')
    parser.add_argument('--not_download_image', action='store_true')
    args = parser.parse_args()

    with open(args.cat_info, 'r', encoding='utf-8') as f:
        categories = json.load(f)['categories']

    os.makedirs(args.save_image_path, exist_ok=True)

    images = []
    with open(args.ann, encoding='utf-8') as f:
        for i, line in enumerate(f):
            # rstrip instead of line[:-1]: the last line may lack a trailing
            # newline, and slicing would then cut off the URL's final character.
            line = line.rstrip('\n')
            if not line:
                continue
            cap, url = line.split('\t')
            print(i, cap, url)

            # Image ids / file names are the 1-based row index in the TSV.
            image_id = i + 1
            file_name = '{}.jpg'.format(image_id)
            image_path = os.path.join(args.save_image_path, file_name)

            if not args.not_download_image:
                _download_image(url, image_path)

            size = _read_image_size(image_path)
            if size is None:
                continue
            h, w = size
            images.append({
                'id': image_id,
                'file_name': file_name,
                'height': h,
                'width': w,
                'captions': [cap],
            })

    data = {'categories': categories, 'images': images, 'annotations': []}
    for k, v in data.items():
        print(k, len(v))
    print('Saving to', args.out_path)
    # Context manager guarantees the JSON is flushed and the file closed.
    with open(args.out_path, 'w', encoding='utf-8') as f:
        json.dump(data, f)


if __name__ == '__main__':
    main()