| """Gradio clone of https://google-research.github.io/vision_transformer/lit/. | |
| Features: | |
| - Models are downloaded dynamically. | |
| - Models are cached on local disk, and in RAM. | |
| - Progress bars when downloading/reading/computing. | |
| - Dynamic update of model controls. | |
| - Dynamic generation of output sliders. | |
| - Use of `gr.State()` for better use of progress bars. | |
| """ | |
import dataclasses
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import PIL.Image

import big_vision_contrastive_models as models
import gradio_helpers
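# Local modules of this Space (not pip packages): as used below,
# `big_vision_contrastive_models` provides `MODEL_CONFIGS`, `load_model()`,
# `setup()` and the `ContrastiveModel` interface, while `gradio_helpers`
# provides the disk/memory caching, cache-info and timing helpers.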
INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'

MAX_ANSWERS = 10
MAX_DISK_CACHE = 20e9
MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM

LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}
# family/variant/res -> name
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            224: 'siglip_so400m14so440m_224',
            384: 'siglip_so400m14so440m_384',
        },
    },
}
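# Example lookup (illustrative): MODEL_MAP['siglip']['B/16'][384] yields
# 'siglip_b16b_384', which `compute()` below resolves to a checkpoint config
# via `models.MODEL_CONFIGS[...]`.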
def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
  """Loads model and computes answers."""
  if image_path is None:
    raise gr.Error('Must first select an image!')

  t0 = time.monotonic()

  model_name = MODEL_MAP[family][variant][res]
  config = models.MODEL_CONFIGS[model_name]
  local_ckpt = gradio_helpers.get_disk_cache(
      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
  config = dataclasses.replace(config, ckpt=local_ckpt)
  params, model = gradio_helpers.get_memory_cache(
      config,
      lambda: models.load_model(config),
      max_cache_size_bytes=MAX_RAM_CACHE,
      progress=progress,
      estimated_secs={
          ('lit', 'B/16'): 1,
          ('lit', 'L/16'): 2.5,
          ('siglip', 'B/16'): 9,
          ('siglip', 'L/16'): 28,
          ('siglip', 'So400m/14'): 36,
      }.get((family, variant))
  )
  model: models.ContrastiveModel = model

  it = progress.tqdm(list(range(3)), desc='compute')

  logging.info('Opening image "%s"', image_path)
  with gradio_helpers.timed(f'opening image "{image_path}"'):
    image = PIL.Image.open(image_path)
  next(it)

  with gradio_helpers.timed('image features'):
    zimg, out = model.embed_images(
        params, model.preprocess_images([image])
    )
  next(it)

  with gradio_helpers.timed('text features'):
    prompts = prompts.split('\n')
    ztxt, out = model.embed_texts(
        params, model.preprocess_texts(prompts)
    )
  next(it)

  t = model.get_temperature(out)
  if family == 'lit':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
  elif family == 'siglip':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0])

  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))

  dt = time.monotonic() - t0
  mem_n, mem_sz = gradio_helpers.get_memory_cache_info()
  disk_n, disk_sz = gradio_helpers.get_disk_cache_info()
  status = gr.Markdown(
      f'Computed inference in {dt:.1f} seconds ('
      f'memory cache {mem_n} items, {mem_sz/1e6:.1f} M, '
      f'disk cache {disk_n} items, {disk_sz/1e6:.1f} M)')
  if 'b' in out:
    logging.info('model_name=%s default bias=%f', model_name, out['b'])

  return status, state
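# Illustrative `state` value handed from `compute()` to `update_answers()` via
# `gr.State`: a list of (prompt, probability) pairs such as
# [('a photo of a cat', 0.972), ('a photo of a dog', 0.013)].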
def update_answers(state):
  """Generates visible sliders for answers."""
  answers = []
  for prompt, prob in state[:MAX_ANSWERS]:
    answers.append(gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
  # Pad with hidden sliders: the handler must return one value for each of the
  # MAX_ANSWERS output components.
  while len(answers) < MAX_ANSWERS:
    answers.append(gr.Slider(visible=False))
  return answers
def create_app():
  """Creates demo UI."""

  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div {
    white-space: pre-wrap !important;
    text-align: left;
  }
  '''

  with gr.Blocks(css=css) as demo:

    gr.Markdown('Gradio clone of the original [LiT demo](https://google-research.github.io/vision_transformer/lit/).')
    status = gr.Markdown()

    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      source = gr.Markdown('', visible=False)
      state = gr.State([])

      with gr.Column():
        prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')

        with gr.Row():
          values = {}
          family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
          values['family'] = family.value
          # Unfortunately the reactive UI code below is a bit convoluted, because:
          # 1. When e.g. `family.change()` updates `variant`, that does not
          #    trigger a `variant.change()`.
          # 2. The widget values like `family.value` are *not* updated when the
          #    widget is updated. Therefore, we keep a manual copy in `values`.
          def make_variant(family_value):
            choices = list(MODEL_MAP[family_value])
            values['variant'] = choices[0]
            return gr.Dropdown(value=values['variant'], choices=choices, label='Variant')

          variant = make_variant(family.value)

          def make_res(family, variant):
            choices = list(MODEL_MAP[family][variant])
            values['res'] = choices[0]
            return gr.Dropdown(value=values['res'], choices=choices, label='Resolution')

          res = make_res(family.value, variant.value)
          values['res'] = res.value

          def make_bias(family, variant, res):
            visible = family == 'siglip'
            value = {
                ('siglip', 'B/16', 224): -12.9,
                ('siglip', 'B/16', 256): -12.7,
                ('siglip', 'L/16', 256): -16.5,
                # ...
            }.get((family, variant, res), -10.0)
            return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)

          bias = make_bias(family.value, variant.value, res.value)
          values['bias'] = bias.value
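          # Background: SigLIP scores each (image, text) pair independently as
          # sigmoid(t * <zimg, ztxt> + bias), so probabilities need not sum to
          # 1 and the bias shifts them globally; LiT instead softmaxes over the
          # prompt list (see the `axis=-1` vs `bias=bias` branches in
          # `compute()` above), which is why this slider is hidden for 'lit'.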
          def family_changed(family):
            variant = list(MODEL_MAP[family])[0]
            res = list(MODEL_MAP[family][variant])[0]
            values['family'] = family
            values['variant'] = variant
            values['res'] = res
            return [
                make_variant(family),
                make_res(family, variant),
                make_bias(family, variant, res),
            ]

          def variant_changed(variant):
            res = list(MODEL_MAP[values['family']][variant])[0]
            values['variant'] = variant
            values['res'] = res
            return [
                make_res(values['family'], variant),
                make_bias(values['family'], variant, res),
            ]

          def res_changed(res):
            return make_bias(values['family'], values['variant'], res)

          family.change(family_changed, family, [variant, res, bias])
          variant.change(variant_changed, variant, [res, bias])
          res.change(res_changed, res, bias)
          # (end of reactive UI code)
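          # Net effect of the wiring above: changing the family rebuilds the
          # variant, resolution and bias widgets; changing the variant rebuilds
          # resolution and bias; changing the resolution rebuilds only bias.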
        run = gr.Button('Run')

        answers = [
            # Will be set to visible in `update_answers()`.
            gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
            for _ in range(MAX_ANSWERS)
        ]

    # We want to avoid showing multiple progress bars, so we only update a
    # single `status` widget here, and store the computed information in
    # `state`...
    run.click(
        fn=compute, inputs=[image, prompts, family, variant, res, bias],
        outputs=[status, state])
    # ... then we use `state` to update UI components without showing a
    # progress bar in their place.
    status.change(fn=update_answers, inputs=state, outputs=answers)

    info = json.load(urllib.request.urlopen(INFO_URL))
    gr.Markdown('Note: the images below have 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        inputs=[image, prompts, source],
        outputs=answers,
        elem_id='examples',
    )

  return demo
| if __name__ == "__main__": | |
| logging.basicConfig(level=logging.INFO, | |
| format='%(asctime)s - %(levelname)s - %(message)s') | |
| for k, v in os.environ.items(): | |
| logging.info('environ["%s"] = %r', k, v) | |
| models.setup() | |
| create_app().queue().launch() | |
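# To run locally (sketch): install the Space's dependencies (gradio, Pillow and
# the JAX/big_vision requirements of `big_vision_contrastive_models`), then
# execute this file with Python; `launch()` starts a local Gradio server.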