Spaces:
Build error
Build error
| from tqdm import tqdm | |
| import requests | |
| import os | |
| import tempfile | |
def download(ckpt_dir, url):
    """Download a checkpoint into a local cache directory and return its path.

    The file name is taken from the last path segment of ``url`` (the query
    string, if any, is ignored). Files are stored in ``<ckpt_dir>/flaxmodels``;
    if the target file already exists, nothing is downloaded.

    Args:
        ckpt_dir: Base directory for the cache. If ``None``, the system
            temporary directory (``tempfile.gettempdir()``) is used.
        url: URL of the file to download.

    Returns:
        Path of the cached file. If the number of bytes received does not
        match the ``content-length`` header, an error message is printed, the
        partial file is removed and the (non-existent) target path is still
        returned, so callers should be prepared for the file to be missing.

    Raises:
        ValueError: If no file name can be derived from ``url``.
        requests.HTTPError: If the server answers with an error status.
    """
    # Strip the query string BEFORE looking for the last '/': this keeps the
    # full name when the URL has no '?' (rfind would return -1 and chop the
    # last character) and ignores any '/' that appears inside the query.
    base_url = url.partition('?')[0]
    name = base_url.rsplit('/', 1)[-1]
    if not name:
        raise ValueError(f'Cannot derive a file name from URL: {url!r}')

    if ckpt_dir is None:
        ckpt_dir = tempfile.gettempdir()
    ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
    ckpt_file = os.path.join(ckpt_dir, name)

    # Cache hit: nothing to do.
    if os.path.exists(ckpt_file):
        return ckpt_file

    print(f'Downloading: "{base_url}" to {ckpt_file}')
    os.makedirs(ckpt_dir, exist_ok=True)

    # Write to a temp file first so an interrupted download never leaves a
    # truncated file under the final name.
    ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
    succeeded = False
    try:
        with requests.get(url, stream=True) as response:
            # Without this, an HTTP error page would be cached as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            bytes_received = 0
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                with open(ckpt_file_temp, 'wb') as file:
                    for data in response.iter_content(chunk_size=1024):
                        progress_bar.update(len(data))
                        file.write(data)
                        bytes_received += len(data)

        # A missing content-length header (0) means the size cannot be verified.
        if total_size_in_bytes != 0 and bytes_received != total_size_in_bytes:
            print('An error occured while downloading, please try again.')
        else:
            # os.replace is atomic and overwrites an existing target on all platforms.
            os.replace(ckpt_file_temp, ckpt_file)
            succeeded = True
    finally:
        # Remove the partial file on size mismatch or on any exception.
        if not succeeded and os.path.exists(ckpt_file_temp):
            os.remove(ckpt_file_temp)

    return ckpt_file