| """Gradio utilities. | |
| Note that the optional `progress` parameter can be both a `tqdm` module or a | |
| `gr.Progress` instance. | |
| """ | |
import concurrent.futures
import contextlib
import glob
import hashlib
import logging
import os
import tempfile
import time
import urllib.request

import jax
import numpy as np
from tensorflow.io import gfile


@contextlib.contextmanager
def timed(name):
  """Context manager that logs the wall-clock duration of the wrapped block."""
  t0 = time.monotonic()
  timing = dict(secs=None)
  try:
    yield timing
  finally:
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])
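

# Example usage (a minimal sketch; `load_params` is a hypothetical callable):
#
#   with timed('loading params') as timing:
#     params = load_params()
#   logging.info('Took %.1f secs', timing['secs'])
#
# The `timing` dict is filled in when the block exits, so it can be read after
# the `with` statement even if the block raised.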


def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False):
  """Copies a file with progress bar.

  Args:
    src: Source file (readable by `tf.io.gfile`) or URL.
    dst: Destination file. Path must be writable by `tf.io.gfile`.
    progress: An object with a `.tqdm` attribute, or `None`.
    block_size: Size of individual blocks to be read/written.
    overwrite: Whether to overwrite `dst` if it already exists.
  """
  if os.path.dirname(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
  if os.path.exists(dst) and not overwrite:
    return

  if src.startswith('http://') or src.startswith('https://'):
    opener = urllib.request.urlopen
    # Issue a HEAD request first to learn the size for the progress bar.
    request = urllib.request.Request(src, method='HEAD')
    response = urllib.request.urlopen(request)
    content_length = response.headers.get('Content-Length')
    n = int(np.ceil(int(content_length) / block_size))
    logging.info('content_length=%s', content_length)
  else:
    opener = lambda path: gfile.GFile(path, 'rb')
    stats = gfile.stat(src)
    n = int(np.ceil(stats.length / block_size))

  if progress is None:
    range_or_trange = range
  else:
    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')

  with opener(src) as fin:
    with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
      for _ in range_or_trange(n):
        fout.write(fin.read(block_size))
  # Rename the `-PARTIAL` file last so `dst` never holds a partial download;
  # `overwrite=True` lets the rename succeed when `overwrite` was requested.
  gfile.rename(f'{dst}-PARTIAL', dst, overwrite=True)
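

# Example usage (a sketch; URL and destination path are hypothetical):
#
#   copy_file('https://example.com/checkpoint.npz', '/tmp/ckpts/checkpoint.npz')
#
# When `progress` is a `gr.Progress` instance (or the `tqdm` module), the
# download advances one progress step per 10 MB block.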


# (estimated, measured) load times in secs, used to refine future estimates.
_estimated_real = [(10, 10)]
# Maps key -> value; dict insertion order doubles as "last accessed" order.
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar."""
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
    for _ in progress.tqdm(list(range(int(np.ceil(secs / step)))), desc='read'):
      if not future.done():
        time.sleep(step)
    return future.result()
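

# Example usage (a sketch; `load_arrays` is a hypothetical blocking loader and
# `progress` would come from a Gradio event handler as `gr.Progress()`):
#
#   arrays = get_with_progress(load_arrays, secs=30, progress=progress)
#
# The bar is sized to `secs` worth of `step`-long ticks; if `getter` finishes
# early, the remaining ticks complete without sleeping.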


def _get_array_sizes(tree):
  return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]


def get_memory_cache(key, getter, max_cache_size_bytes, progress=None, estimated_secs=None):
  """Keeps cache below specified size by evicting least recently accessed items."""
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # Update "last accessed" order.
    return _memory_cache[key]

  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  with timed(f'loading {key}') as timing:
    estimated_secs *= sum(real) / sum(est)
    _memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, timing['secs']))

  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz / 1e6)

  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s / 1e6)
    _memory_cache.pop(k)
    sz -= s

  return _memory_cache[key]
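

# Example usage (a sketch; `load_model` is a hypothetical, expensive loader):
#
#   params = get_memory_cache(
#       key='model-L/16',
#       getter=lambda: load_model('model-L/16'),
#       max_cache_size_bytes=8 * 10**9,
#   )
#
# A repeated call with the same key returns the cached value and refreshes its
# "last accessed" position; otherwise the least recently used entries are
# evicted until the cache fits the budget again.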


def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = _get_array_sizes(_memory_cache)
  return len(_memory_cache), sum(sizes)


CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')


def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Keeps cache below specified size by evicting least recently accessed files."""
  fname = os.path.basename(path_or_url)
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  atimes_sizes_paths = sorted([
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ])
  sz = sum(sz for _, sz, _ in atimes_sizes_paths)
  logging.info('New disk cache size=%.1f MB', sz / 1e6)

  while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)
    if path == dst:
      break
    logging.info('Removing %s from disk cache (%.1f MB)', path, s / 1e6)
    os.unlink(path)
    sz -= s

  return dst
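

# Example usage (a sketch; the GCS path is hypothetical):
#
#   local_path = get_disk_cache(
#       'gs://some-bucket/checkpoints/model.npz',
#       max_cache_size_bytes=50 * 10**9,
#   )
#
# Files are keyed by an MD5 hash of the full source path, so distinct sources
# that share a basename do not collide, and the least recently accessed files
# are deleted first when the cache exceeds its budget.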


def get_disk_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = [
      os.path.getsize(p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
  ]
  return len(sizes), sum(sizes)