Spaces:
Sleeping
Sleeping
Dynamic UI + model loading.
Browse files- .gitignore +2 -0
- README.md +4 -4
- app.py +279 -0
- big_vision_contrastive_models.py +241 -0
- gradio_helpers.py +165 -0
- requirements.txt +12 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/env
|
| 2 |
+
/__pycache__
|
README.md
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.21.0
|
| 8 |
app_file: app.py
|
| 9 |
-
pinned:
|
| 10 |
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
title: LiT Demo (big_vision)
|
| 3 |
+
emoji: π
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.21.0
|
| 8 |
app_file: app.py
|
| 9 |
+
pinned: true
|
| 10 |
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
| 13 |
+
Gradio clone of the original [LiT Demo](https://google-research.github.io/vision_transformer/lit/)
|
app.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio clone of https://google-research.github.io/vision_transformer/lit/.
|
| 2 |
+
|
| 3 |
+
Features:
|
| 4 |
+
|
| 5 |
+
- Models are downloaded dynamically.
|
| 6 |
+
- Models are cached on local disk, and in RAM.
|
| 7 |
+
- Progress bars when downloading/reading/computing.
|
| 8 |
+
- Dynamic update of model controls.
|
| 9 |
+
- Dynamic generation of output sliders.
|
| 10 |
+
- Use of `gr.State()` for better use of progress bars.
|
| 11 |
+
"""
|
| 12 |
+
import dataclasses
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
import urllib.request
|
| 18 |
+
|
| 19 |
+
import gradio as gr
|
| 20 |
+
import PIL.Image
|
| 21 |
+
|
| 22 |
+
import big_vision_contrastive_models as models
|
| 23 |
+
import gradio_helpers
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
MAX_ANSWERS = 10

MAX_DISK_CACHE = 20e9
MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM

LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}


# family/variant/res -> name (keys into `models.MODEL_CONFIGS`)
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            224: 'siglip_so400m14so440m_224',
            # Bug fix: this entry previously read 'siglip_so400m14so440m_384',
            # which is not a key of `models.MODEL_CONFIGS` (the 384px config is
            # registered as 'siglip_so400m14so400m_384') and raised a KeyError
            # when the user selected So400m/14 @ 384.
            384: 'siglip_so400m14so400m_384',
        },
    },
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
  """Loads the selected model (with caching) and computes per-prompt scores.

  Args:
    image_path: Path of the selected image, or `None` if not yet selected.
    prompts: Newline-separated text prompts.
    family: Model family, a key of `MODEL_MAP` ('lit' or 'siglip').
    variant: Image tower variant, e.g. 'B/16'.
    res: Input image resolution (int).
    bias: Logit bias; only used by 'siglip' models.
    progress: Gradio progress tracker (the instance-as-default is gradio's
      documented idiom for enabling progress bars).

  Returns:
    Tuple `(status, state)`: a `gr.Markdown` with timing/cache info, and a
    list of `(prompt, probability)` tuples for `update_answers()`.
  """
  if image_path is None:
    raise gr.Error('Must first select an image!')

  t0 = time.monotonic()

  model_name = MODEL_MAP[family][variant][res]
  config = models.MODEL_CONFIGS[model_name]
  # Checkpoints are downloaded into (and LRU-evicted from) a disk cache.
  local_ckpt = gradio_helpers.get_disk_cache(
      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
  config = dataclasses.replace(config, ckpt=local_ckpt)
  # Loaded models are kept in an in-RAM LRU cache; `estimated_secs` lets the
  # progress bar advance at roughly the right speed on a cache miss.
  params, model = gradio_helpers.get_memory_cache(
      config,
      lambda: models.load_model(config),
      max_cache_size_bytes=MAX_RAM_CACHE,
      progress=progress,
      estimated_secs={
          ('lit', 'B/16'): 1,
          ('lit', 'L/16'): 2.5,
          ('siglip', 'B/16'): 9,
          ('siglip', 'L/16'): 28,
          ('siglip', 'So400m/14'): 36,
      }.get((family, variant))
  )
  model: models.ContrastiveModel = model

  it = progress.tqdm(list(range(3)), desc='compute')

  logging.info('Opening image "%s"', image_path)
  with gradio_helpers.timed(f'opening image "{image_path}"'):
    image = PIL.Image.open(image_path)
  next(it)
  with gradio_helpers.timed('image features'):
    zimg, out = model.embed_images(
        params, model.preprocess_images([image])
    )
  next(it)
  with gradio_helpers.timed('text features'):
    prompts = prompts.split('\n')
    ztxt, out = model.embed_texts(
        params, model.preprocess_texts(prompts)
    )
  next(it)

  t = model.get_temperature(out)
  if family == 'lit':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
  elif family == 'siglip':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0])
  else:
    # Bug fix: an unrecognized family previously fell through and raised an
    # opaque UnboundLocalError on `text_probs` below.
    raise gr.Error(f'Unknown model family: {family!r}')

  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))

  dt = time.monotonic() - t0
  mem_n, mem_sz = gradio_helpers.get_memory_cache_info()
  disk_n, disk_sz = gradio_helpers.get_disk_cache_info()
  status = gr.Markdown(
      f'Computed inference in {dt:.1f} seconds ('
      f'memory cache {mem_n} items, {mem_sz/1e6:.1f} M, '
      f'disk cache {disk_n} items, {disk_sz/1e6:.1f} M)')

  if 'b' in out:
    logging.info('model_name=%s default bias=%f', model_name, out['b'])

  return status, state
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def update_answers(state):
  """Builds MAX_ANSWERS sliders: one visible per answer, the rest hidden."""
  shown = [
      gr.Slider(value=round(100 * prob, 2), label=prompt, visible=True)
      for prompt, prob in state[:MAX_ANSWERS]
  ]
  hidden = [gr.Slider(visible=False) for _ in range(MAX_ANSWERS - len(shown))]
  return shown + hidden
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def create_app():
  """Creates demo UI.

  Returns:
    A `gr.Blocks` demo with dynamic model controls and answer sliders.
  """

  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div {
    white-space: pre-wrap !important;
    text-align: left;
  }
  '''

  with gr.Blocks(css=css) as demo:

    gr.Markdown('Gradio clone of the original [LiT demo](https://google-research.github.io/vision_transformer/lit/).')

    status = gr.Markdown()

    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      source = gr.Markdown('', visible=False)
      state = gr.State([])
      with gr.Column():
        prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
        with gr.Row():

          values = {}

          family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
          values['family'] = family.value

          # Unfortunately below reactive UI code is a bit convoluted, because:
          # 1. When e.g. `family.change()` updates `variant`, then that does not
          #    trigger a `variant.change()`.
          # 2. The widget values like `family.value` are *not* updated when the
          #    widget is updated. Therefore, we keep a manual copy in `values`.

          def make_variant(family_value):
            choices = list(MODEL_MAP[family_value])
            values['variant'] = choices[0]
            return gr.Dropdown(value=values['variant'], choices=choices, label='Variant')
          variant = make_variant(family.value)

          def make_res(family, variant):
            choices = list(MODEL_MAP[family][variant])
            values['res'] = choices[0]
            return gr.Dropdown(value=values['res'], choices=choices, label='Resolution')
          res = make_res(family.value, variant.value)
          values['res'] = res.value

          def make_bias(family, variant, res):
            # The bias slider is only meaningful for sigmoid (SigLIP) models.
            visible = family == 'siglip'
            value = {
                ('siglip', 'B/16', 224): -12.9,
                ('siglip', 'L/16', 256): -12.7,
                # Bug fix: this key was previously duplicated as
                # ('siglip', 'L/16', 256), silently overriding the -12.7 entry
                # above; 384 is the other L/16 resolution in MODEL_MAP.
                ('siglip', 'L/16', 384): -16.5,
                # ...
            }.get((family, variant, res), -10.0)
            return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)
          bias = make_bias(family.value, variant.value, res.value)
          values['bias'] = bias.value

          def family_changed(family):
            variant = list(MODEL_MAP[family])[0]
            res = list(MODEL_MAP[family][variant])[0]
            values['family'] = family
            values['variant'] = variant
            values['res'] = res
            return [
                make_variant(family),
                make_res(family, variant),
                make_bias(family, variant, res),
            ]

          def variant_changed(variant):
            res = list(MODEL_MAP[values['family']][variant])[0]
            values['variant'] = variant
            values['res'] = res
            return [
                make_res(values['family'], variant),
                make_bias(values['family'], variant, res),
            ]

          def res_changed(res):
            return make_bias(values['family'], values['variant'], res)

          family.change(family_changed, family, [variant, res, bias])
          variant.change(variant_changed, variant, [res, bias])
          res.change(res_changed, res, bias)

          # (end of code for reactive UI code)

        run = gr.Button('Run')
        answers = [
            # Will be set to visible in `update_answers()`.
            gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
            for _ in range(MAX_ANSWERS)
        ]

    # We want to avoid showing multiple progress bars, so we only update
    # a single `status` widget here, and store the computed information in
    # `state`...
    run.click(
        fn=compute, inputs=[image, prompts, family, variant, res, bias], outputs=[status, state])
    # ... then we use `state` to update UI components without showing a
    # progress bar in their place.
    status.change(fn=update_answers, inputs=state, outputs=answers)

    info = json.load(urllib.request.urlopen(INFO_URL))
    gr.Markdown('Note: below images have 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        # Bug fix: `inputs` previously listed a fourth component, the Python
        # builtin `license` (no such widget exists), while each example row
        # only has three columns.
        inputs=[image, prompts, source],
        outputs=answers,
        elem_id='examples',
    )

  return demo
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# Script entry point: configure logging, fetch big_vision sources, launch UI.
if __name__ == "__main__":

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  # NOTE(review): this logs *every* environment variable, which on a hosted
  # Space may include secrets (HF tokens etc.) — consider redacting.
  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  # Clones the big_vision/flaxformer repos and adds them to `sys.path`.
  models.setup()

  # `.queue()` is required for the `gr.Progress` bars to stream updates.
  create_app().queue().launch()
|
big_vision_contrastive_models.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Wrapper for big_vision contrastive models.
|
| 2 |
+
|
| 3 |
+
Before using any of the functions, make sure to call `setup()`.
|
| 4 |
+
|
| 5 |
+
Choose one of the configs in `MODEL_CONFIGS` and then call `load_model()` to get
|
| 6 |
+
the params and model wrapper.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import dataclasses
|
| 10 |
+
import enum
|
| 11 |
+
import functools
|
| 12 |
+
import importlib
|
| 13 |
+
import os
|
| 14 |
+
import subprocess
|
| 15 |
+
import sys
|
| 16 |
+
import tempfile
|
| 17 |
+
|
| 18 |
+
import flax.linen as nn
|
| 19 |
+
import jax
|
| 20 |
+
import jax.numpy as jnp
|
| 21 |
+
import ml_collections
|
| 22 |
+
import numpy as np
|
| 23 |
+
import PIL.Image
|
| 24 |
+
import sentencepiece
|
| 25 |
+
from tensorflow.io import gfile
|
| 26 |
+
import transformers
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _clone_git(url, destination_folder, commit_hash=None):
  """Clones `url` into `destination_folder`, optionally at `commit_hash`.

  Args:
    url: Git repository URL.
    destination_folder: Local path to clone into.
    commit_hash: If given, check out this commit after cloning.

  Raises:
    subprocess.CalledProcessError: If any git command fails.
  """
  # Bug fix: a `--depth=1` shallow clone only contains the tip commit, so a
  # subsequent `git checkout <commit_hash>` of any other commit fails. Only
  # shallow-clone when no specific commit was requested.
  clone_cmd = ['git', 'clone', url, destination_folder]
  if not commit_hash:
    clone_cmd.insert(2, '--depth=1')
  subprocess.run(clone_cmd, check=True)
  if commit_hash:
    subprocess.run(
        ['git', '-C', destination_folder, 'checkout', commit_hash], check=True)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def setup(commit_hash=None):
  """Fetches big_vision / flaxformer sources and puts them on `sys.path`."""
  repos = (
      ('https://github.com/google-research/big_vision', 'big_vision_repo'),
      ('https://github.com/google/flaxformer', 'flaxformer_repo'),
  )
  tmpdir = tempfile.gettempdir()
  for url, folder in repos:
    target = os.path.join(tmpdir, folder)
    # Clone at most once per container lifetime.
    if not os.path.exists(target):
      _clone_git(url, target, commit_hash)
    if target not in sys.path:
      sys.path.insert(0, target)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ContrastiveModelFamily(enum.Enum):
  """The supported `big_vision` contrastive model families."""

  LIT = 'lit'
  SIGLIP = 'siglip'

  @property
  def paper(self):
    """URL of the paper introducing this model family."""
    papers = {
        ContrastiveModelFamily.LIT: 'https://arxiv.org/abs/2111.07991',
        ContrastiveModelFamily.SIGLIP: 'https://arxiv.org/abs/2303.15343',
    }
    return papers[self]

  def __lt__(self, other):
    # Order by string value so configs containing a family remain sortable.
    return self.value < other.value
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
class ContrastiveModelConfig:
  """Describes a `big_vision` contrastive model.

  Used as a hashable/orderable key for caching and as the argument to
  `load_model()`.
  """
  # Model family (LiT or SigLIP).
  family: ContrastiveModelFamily
  # Image tower variant, e.g. 'B/16'.
  variant: str
  # Input image resolution in pixels (square).
  res: int
  # Text tower variant, e.g. 'B', 'L', 'So400m'.
  textvariant: str
  # Embedding dimensionality.
  embdim: int
  # Maximum tokenized text sequence length.
  seqlen: int
  # Tokenizer short name (e.g. 'c4_en') or vocab file path.
  tokenizer: str
  # Tokenizer vocabulary size.
  vocab_size: int
  # Checkpoint path (e.g. a `gs://` URL).
  ckpt: str
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class ContrastiveModel:
  """Wraps a `big_vision` contrastive model.

  Exactly one of `tokenizer_sp` (SigLIP) / `tokenizer_bert` (LiT) is set.
  """

  config: ContrastiveModelConfig
  flax_module: nn.Module
  tokenizer_sp: sentencepiece.SentencePieceProcessor | None
  tokenizer_bert: transformers.BertTokenizer | None

  def embed_images(self, params, images):
    """Returns `(zimg, out)`: image embeddings and auxiliary outputs."""
    # `getattr` default: without it a non-array argument raised AttributeError
    # instead of the helpful assertion message.
    assert getattr(images, 'ndim', None) == 4, 'Must call `.preprocess_images()`'
    zimg, _, out = self.flax_module.apply(dict(params=params), images, None)
    return zimg, out

  def embed_texts(self, params, texts):
    """Returns `(ztxt, out)`: text embeddings and auxiliary outputs."""
    assert getattr(texts, 'ndim', None) == 2, 'Must call `.preprocess_texts()`'
    _, ztxt, out = self.flax_module.apply(dict(params=params), None, texts)
    return ztxt, out

  def preprocess_texts(self, texts):
    """Tokenizes and zero-pads `texts` into an int array (n, seqlen)."""

    def tokenize_pad(text, seqlen=self.config.seqlen):

      if self.config.family == ContrastiveModelFamily.LIT:
        tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)[:-1]  # removes [SEP]
        tokens = tokens[:seqlen]
        return tokens + [0] * (seqlen - len(tokens))

      if self.config.family == ContrastiveModelFamily.SIGLIP:
        tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
        if len(tokens) >= seqlen:
          # Bug fix: previously called `tok.eos_id()` where `tok` was
          # undefined — a NameError for any text tokenizing to >= seqlen.
          return tokens[:seqlen - 1] + [self.tokenizer_sp.eos_id()]  # "sticky" eos
        return tokens + [0] * (seqlen - len(tokens))

    return np.array([tokenize_pad(text) for text in texts])

  def preprocess_images(self, images):
    """Resizes images and scales pixel values to [-1, 1]; (n, res, res, 3)."""
    if not isinstance(images, (list, tuple)):
      images = [images]
    def topil(image):
      if not isinstance(image, PIL.Image.Image):
        image = PIL.Image.fromarray(image)
      return image
    return np.array([
        topil(image).resize([self.config.res, self.config.res])
        for image in images
    ]) / 127.5 - 1.0

  def get_bias(self, out):
    """Returns the learned logit bias; SigLIP models only."""
    assert self.config.family == ContrastiveModelFamily.SIGLIP, self.config.family
    return out['b'].item()

  def get_temperature(self, out):
    """Returns the learned temperature."""
    return out['t'].item()

  def get_probabilities(self, zimg, ztxt, temperature, *, axis=None, bias=None):
    """Returns probabilities: softmax over `axis` (LiT) or sigmoid (SigLIP)."""
    # Note: zimg, ztxt are already normalized.

    if self.config.family == ContrastiveModelFamily.LIT:
      assert bias is None
      assert axis in (-1, -2), 'Must specify axis: -1/-2=normalize texts/images'
      return jax.nn.softmax(zimg @ ztxt.T * temperature, axis=axis)

    if self.config.family == ContrastiveModelFamily.SIGLIP:
      assert axis is None
      assert bias is not None, 'Must specify bias.'
      return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size):
  """Builds a `ContrastiveModelConfig` from shorthand arguments."""
  if family == 'lit':
    # LiT checkpoints ship their BERT vocab file next to the model weights.
    tokenizer = ckpt.replace('.npz', '.txt')
  else:
    tokenizer = 'c4_en'
  return ContrastiveModelConfig(
      family=ContrastiveModelFamily(family), variant=variant, res=res,
      textvariant=textvariant, embdim=embdim, seqlen=seqlen,
      # Bug fix: `vocab_size` was previously hard-coded to 32_000 here,
      # silently ignoring the parameter.
      tokenizer=tokenizer, vocab_size=vocab_size,
      ckpt=ckpt,
  )
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Registry of known checkpoints, keyed by short name.
# `_make_config` positional args: family, variant, res, textvariant, ckpt,
# embdim, seqlen, vocab_size.
# NOTE(review): `lit_b16s` / `lit_b16ti` declare variant 'L/16' with L-sized
# checkpoints and embdim=1024, which looks like a copy-paste from `lit_l16l`
# (the names suggest B/16 towers) — confirm against the published LiT models.
# NOTE(review): the So400m 224px key spells 'so440m' while the 384px key
# spells 'so400m'; `MODEL_MAP` in app.py must agree with whichever is kept.
MODEL_CONFIGS = dict(
    lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
    lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
    lit_b16s=_make_config('lit', 'L/16', 224, 'S', 'gs://vit_models/lit/LiT-L16S.npz', 1024, 16, 32_000),
    lit_b16ti=_make_config('lit', 'L/16', 224, 'Ti', 'gs://vit_models/lit/LiT-L16Ti.npz', 1024, 16, 32_000),

    siglip_b16b_224=_make_config('siglip', 'B/16', 224, 'B', 'gs://big_vision/siglip/webli_en_b16_224_63724782.npz', 768, 64, 32_000),
    siglip_b16b_256=_make_config('siglip', 'B/16', 256, 'B', 'gs://big_vision/siglip/webli_en_b16_256_60500360.npz', 768, 64, 32_000),
    siglip_b16b_384=_make_config('siglip', 'B/16', 384, 'B', 'gs://big_vision/siglip/webli_en_b16_384_68578854.npz', 768, 64, 32_000),
    siglip_b16b_512=_make_config('siglip', 'B/16', 512, 'B', 'gs://big_vision/siglip/webli_en_b16_512_68580893.npz', 768, 64, 32_000),
    siglip_l16l_256=_make_config('siglip', 'L/16', 256, 'L', 'gs://big_vision/siglip/webli_en_l16_256_60552751.npz', 1024, 64, 32_000),
    siglip_l16l_384=_make_config('siglip', 'L/16', 384, 'L', 'gs://big_vision/siglip/webli_en_l16_384_63634585.npz', 1024, 64, 32_000),
    # NOTE(review): seqlen=16 here vs. 64 for the 384px sibling — verify.
    siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
    siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@functools.cache
def load_tokenizer_sp(name_or_path):
  """Loads (and memoizes) a SentencePiece tokenizer by short name or path."""
  known_vocabs = {
      'c4_en': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
  }
  path = known_vocabs.get(name_or_path, name_or_path)
  tok = sentencepiece.SentencePieceProcessor()
  tok.LoadFromSerializedProto(gfile.GFile(path, 'rb').read())
  return tok
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@functools.cache
def load_tokenizer_bert(path):
  """Loads (and memoizes) a BERT tokenizer from a local or `gs://` vocab file.

  Args:
    path: Local path or `gs://` URL of the BERT vocab file.

  Returns:
    A `transformers.BertTokenizer`.
  """
  # Bug fix: removed an unused `sentencepiece.SentencePieceProcessor()` local
  # (dead code left over from `load_tokenizer_sp`).
  if path.startswith('gs://'):
    # `BertTokenizer` needs a local file. Use mkstemp instead of the insecure,
    # deprecated `tempfile.mktemp` (race between name choice and creation).
    fd, dst = tempfile.mkstemp()
    os.close(fd)
    gfile.copy(path, dst, overwrite=True)
    path = dst
  return transformers.BertTokenizer(path, do_lower_case=True)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def load_model(config, check_params=False):
  """Loads `big_vision` model.

  Args:
    config: A `ContrastiveModelConfig` describing the model to build/load.
    check_params: If True, initialize params first so that the checkpoint
      loader can sanity-check shapes against them (slower).

  Returns:
    Tuple `(params_cpu, ContrastiveModel)`.
  """
  assert isinstance(config, ContrastiveModelConfig), type(config)

  # Build the two_towers model config; the branches mirror the big_vision
  # training configs for LiT vs. SigLIP.
  cfg = ml_collections.ConfigDict()
  cfg.image_model = 'vit'  # TODO(lbeyer): remove later, default
  if config.family == ContrastiveModelFamily.LIT:
    cfg.text_model = 'proj.flaxformer.bert'
    cfg.image = dict(variant=config.variant, pool_type='tok', head_zeroinit=False)
    bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
    cfg.text = dict(config=bert_config, head_zeroinit=False)
    tokenizer_bert = load_tokenizer_bert(config.tokenizer)
    tokenizer_sp = None
    if config.variant == 'L/16':
      # L/16 checkpoints have no image head; only the text tower projects.
      cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
    else:
      cfg.out_dim = (config.embdim, config.embdim)  # (image_out_dim, text_out_dim)
  else:
    # SigLIP family.
    cfg.image = dict(variant=config.variant, pool_type='map')
    cfg.text_model = 'proj.image_text.text_transformer'  # TODO(lbeyer): remove later, default
    cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
    cfg.bias_init = -10.0
    tokenizer_sp = load_tokenizer_sp(config.tokenizer)
    tokenizer_bert = None
    cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
  cfg.temperature_init = 10.0

  # Imported lazily: big_vision is only on sys.path after `setup()` ran.
  model_mod = importlib.import_module(
      'big_vision.models.proj.image_text.two_towers')
  model = model_mod.Model(**cfg)

  init_params = None  # Faster but bypasses loading sanity-checks.
  if check_params:
    imgs = jnp.zeros([1, config.res, config.res, 3])
    txts = jnp.zeros([1, config.seqlen], jnp.int32)
    init_params = model.init(jax.random.PRNGKey(0), imgs, txts)['params']
  params_cpu = model_mod.load(init_params, config.ckpt, cfg)

  return params_cpu, ContrastiveModel(
      config=config,
      flax_module=model,
      tokenizer_sp=tokenizer_sp,
      tokenizer_bert=tokenizer_bert,
  )
|
gradio_helpers.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio utilities.
|
| 2 |
+
|
| 3 |
+
Note that the optional `progress` parameter can be both a `tqdm` module or a
|
| 4 |
+
`gr.Progress` instance.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import concurrent.futures
|
| 8 |
+
import contextlib
|
| 9 |
+
import glob
|
| 10 |
+
import hashlib
|
| 11 |
+
import logging
|
| 12 |
+
import os
|
| 13 |
+
import tempfile
|
| 14 |
+
import time
|
| 15 |
+
import urllib.request
|
| 16 |
+
|
| 17 |
+
import jax
|
| 18 |
+
import numpy as np
|
| 19 |
+
from tensorflow.io import gfile
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@contextlib.contextmanager
def timed(name):
  """Context manager that measures and logs elapsed wall-clock time.

  Args:
    name: Label used in the log message.

  Yields:
    A dict whose 'secs' entry is filled in with the elapsed seconds on exit
    (also on exceptions).
  """
  t0 = time.monotonic()
  # Bug fix: the placeholder key was previously named 'dt', but the value is
  # stored (and read by callers) under 'secs'; initialize the key actually used.
  timing = dict(secs=None)
  try:
    yield timing
  finally:
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False):
  """Copies a file with progress bar.

  Args:
    src: Source file (readable by `tf.io.gfile`) or URL.
    dst: Destination file. Path must be readable by `tf.io.gfile`.
    progress: An object with a `.tqdm` attribute, or `None`.
    block_size: Size of individual blocks to be read/written.
    overwrite: If False (the default), silently return when `dst` exists.
  """
  if os.path.dirname(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
  if os.path.exists(dst) and not overwrite:
    return

  if src.startswith('http://') or src.startswith('https://'):
    opener = urllib.request.urlopen
    # HEAD request to learn the size, for the progress bar step count.
    # NOTE(review): assumes the server sends Content-Length — int(None) would
    # raise otherwise; confirm for the URLs actually used.
    request = urllib.request.Request(src, method='HEAD')
    response = urllib.request.urlopen(request)
    content_length = response.headers.get('Content-Length')
    n = int(np.ceil(int(content_length) / block_size))
    # Bug fix: removed a leftover debug `print('content_length', ...)`.
  else:
    opener = lambda path: gfile.GFile(path, 'rb')
    stats = gfile.stat(src)
    n = int(np.ceil(stats.length / block_size))

  if progress is None:
    range_or_trange = range
  else:
    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')

  # Write to a temporary name and rename at the end, so an interrupted
  # download is never mistaken for a complete file.
  with opener(src) as fin:
    with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
      for _ in range_or_trange(n):
        fout.write(fin.read(block_size))
  gfile.rename(f'{dst}-PARTIAL', dst)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
_estimated_real = [(10, 10)]
|
| 73 |
+
_memory_cache = {}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar."""
  n_ticks = int(np.ceil(secs / step))
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
    # Advance one tick per `step` seconds until the work is done (or the
    # estimated tick budget is exhausted).
    for _ in progress.tqdm(list(range(n_ticks)), desc='read'):
      if not future.done():
        time.sleep(step)
    return future.result()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _get_array_sizes(tree):
|
| 87 |
+
return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_memory_cache(key, getter, max_cache_size_bytes, progress=None, estimated_secs=None):
  """Keeps cache below specified size by removing elements not last accessed.

  Args:
    key: Hashable cache key.
    getter: Zero-arg callable producing the value on a cache miss.
    max_cache_size_bytes: LRU-evict entries while the cache exceeds this size.
    progress: Optional `tqdm` module or `gr.Progress` instance.
    estimated_secs: Expected duration of `getter()`, used to pace the progress
      bar; if `None`, the average of previously recorded estimates is used.

  Returns:
    The cached (or freshly computed) value for `key`.
  """
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # updated "last accessed" order
    return _memory_cache[key]

  # Calibrate the estimate by the observed ratio of real vs. estimated
  # durations from previous loads (seeded with (10, 10)).
  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  with timed(f'loading {key}') as timing:
    estimated_secs *= sum(real) / sum(est)
    _memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
  # `timing['secs']` is only filled in once the `with` block exits.
  _estimated_real.append((estimated_secs, timing['secs']))

  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz/1e6)

  # Evict oldest-accessed entries (dict preserves insertion order) until the
  # cache fits, but never evict the entry we just added.
  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
    _memory_cache.pop(k)
    sz -= s

  return _memory_cache[key]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  return len(_memory_cache), sum(_get_array_sizes(_memory_cache))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Downloads a file into the disk cache and returns its local path.

  Keeps the cache below the specified size by removing the files least
  recently accessed.

  Args:
    path_or_url: Source path/URL (anything `copy_file()` accepts).
    max_cache_size_bytes: Evict old files while the cache exceeds this size.
    progress: Optional object with a `.tqdm` attribute, or `None`.

  Returns:
    Local filesystem path of the cached file.
  """
  fname = os.path.basename(path_or_url)
  # Hash the full source path to avoid collisions between same-named files.
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  # Oldest-accessed first, so eviction below removes LRU files.
  atimes_sizes_paths = sorted([
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ])
  sz = sum(s for _, s, _ in atimes_sizes_paths)
  logging.info('New disk cache size=%.1f MB', sz/1e6)

  while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)
    if path == dst:
      break  # never evict the file we just downloaded
    # Bug fixes: previously this logged the *total* cache size (`sz`) instead
    # of the file's size, said "memory cache", and called `os.unlink(fname)`
    # with a bare basename — deleting nothing (or the wrong file).
    logging.info('Removing %s from disk cache (%.1f MB)', path, s/1e6)
    os.unlink(path)
    sz -= s

  return dst
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def get_disk_cache_info():
  """Returns number of items and total size in bytes."""
  paths = glob.glob(os.path.join(CACHE_DIR, '*', '*'))
  return len(paths), sum(os.path.getsize(p) for p in paths)
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aqtp # for flaxformer
|
| 2 |
+
einops
|
| 3 |
+
flax
|
| 4 |
+
gradio
|
| 5 |
+
jax
|
| 6 |
+
jaxlib
|
| 7 |
+
ml_collections
|
| 8 |
+
numpy
|
| 9 |
+
Pillow
|
| 10 |
+
sentencepiece
|
| 11 |
+
transformers # for transformers.BertTokenizer
|
| 12 |
+
tensorflow # for tf.io.gfile
|