Spaces:
Runtime error
Runtime error
| # Copyright 2020 Erik Härkönen. All rights reserved. | |
| # This file is licensed to you under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. You may obtain a copy | |
| # of the License at http://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software distributed under | |
| # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS | |
| # OF ANY KIND, either express or implied. See the License for the specific language | |
| # governing permissions and limitations under the License. | |
| import string | |
| import numpy as np | |
| from pathlib import Path | |
| import requests | |
| import pickle | |
| import sys | |
| import re | |
| import gdown | |
| def prettify_name(name): | |
| valid = "-_%s%s" % (string.ascii_letters, string.digits) | |
| return ''.join(map(lambda c : c if c in valid else '_', name)) | |
| # Add padding to sequence of images | |
| # Used in conjunction with np.hstack/np.vstack | |
| # By default: adds one 64th of the width of horizontal padding | |
| def pad_frames(strip, pad_fract_horiz=64, pad_fract_vert=0, pad_value=None): | |
| dtype = strip[0].dtype | |
| if pad_value is None: | |
| if dtype in [np.float32, np.float64]: | |
| pad_value = 1.0 | |
| else: | |
| pad_value = np.iinfo(dtype).max | |
| frames = [strip[0]] | |
| for frame in strip[1:]: | |
| if pad_fract_horiz > 0: | |
| frames.append(pad_value*np.ones((frame.shape[0], frame.shape[1]//pad_fract_horiz, 3), dtype=dtype)) | |
| elif pad_fract_vert > 0: | |
| frames.append(pad_value*np.ones((frame.shape[0]//pad_fract_vert, frame.shape[1], 3), dtype=dtype)) | |
| frames.append(frame) | |
| return frames | |
| def download_google_drive(url, output_name): | |
| print('Downloading', url) | |
| gdown.download(url, str(output_name)) | |
| # session = requests.Session() | |
| # r = session.get(url, allow_redirects=True) | |
| # r.raise_for_status() | |
| # # Google Drive virus check message | |
| # if r.encoding is not None: | |
| # tokens = re.search('(confirm=.+)&id', str(r.content)) | |
| # assert tokens is not None, 'Could not extract token from response' | |
| # url = url.replace('id=', f'{tokens[1]}&id=') | |
| # r = session.get(url, allow_redirects=True) | |
| # r.raise_for_status() | |
| # assert r.encoding is None, f'Failed to download weight file from {url}' | |
| # with open(output_name, 'wb') as f: | |
| # f.write(r.content) | |
| def download_generic(url, output_name): | |
| print('Downloading', url) | |
| session = requests.Session() | |
| r = session.get(url, allow_redirects=True) | |
| r.raise_for_status() | |
| # No encoding means raw data | |
| if r.encoding is None: | |
| with open(output_name, 'wb') as f: | |
| f.write(r.content) | |
| else: | |
| download_manual(url, output_name) | |
| def download_manual(url, output_name): | |
| outpath = Path(output_name).resolve() | |
| while not outpath.is_file(): | |
| print('Could not find checkpoint') | |
| print(f'Please download the checkpoint from\n{url}\nand save it as\n{outpath}') | |
| input('Press any key to continue...') | |
| def download_ckpt(url, output_name): | |
| if 'drive.google' in url: | |
| download_google_drive(url, output_name) | |
| elif 'mega.nz' in url: | |
| download_manual(url, output_name) | |
| else: | |
| download_generic(url, output_name) |