Spaces:
Sleeping
Sleeping
| """Wrapper for big_vision contrastive models. | |
| Before using any of the functions, make sure to call `setup()`. | |
| Choose one of the configs in `MODEL_CONFIGS` and then call `load_model()` to get | |
| the params and model wrapper. | |
| """ | |
| import dataclasses | |
| import enum | |
| import functools | |
| import importlib | |
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import flax.linen as nn | |
| import jax | |
| import jax.numpy as jnp | |
| import ml_collections | |
| import numpy as np | |
| import PIL.Image | |
| import sentencepiece | |
| from tensorflow.io import gfile | |
| import transformers | |
def _clone_git(url, destination_folder, commit_hash=None):
    """Shallow-clones a git repository, optionally pinned to a commit.

    Args:
      url: URL of the repository to clone.
      destination_folder: Path the repository is cloned into.
      commit_hash: Optional commit to check out after cloning.

    Raises:
      subprocess.CalledProcessError: If cloning, fetching or checking out
        fails.
    """
    subprocess.run(
        ['git', 'clone', '--depth=1', url, destination_folder], check=True)
    if not commit_hash:
        return

    git = ['git', '-C', destination_folder]
    # A `--depth=1` clone only contains the tip of the default branch, so any
    # other commit must be fetched explicitly before it can be checked out.
    probe = subprocess.run(
        git + ['cat-file', '-e', f'{commit_hash}^{{commit}}'],
        check=False,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if probe.returncode != 0:
        subprocess.run(
            git + ['fetch', '--depth=1', 'origin', commit_hash], check=True)
    subprocess.run(git + ['checkout', commit_hash], check=True)
def setup(commit_hash=None):
    """Makes the `big_vision` and `flaxformer` sources importable.

    Each repository is cloned into the system temp directory unless a folder
    with the expected name already exists there, and the folder is then put at
    the front of `sys.path` if it is not listed yet.

    Args:
      commit_hash: Optional commit forwarded to `_clone_git()`.
        NOTE(review): the same value is used for both repositories -- confirm
        that this is intended.
    """
    repos = {
        'big_vision_repo': 'https://github.com/google-research/big_vision',
        'flaxformer_repo': 'https://github.com/google/flaxformer',
    }
    tmp_root = tempfile.gettempdir()
    for folder, repo_url in repos.items():
        checkout_dir = os.path.join(tmp_root, folder)
        if not os.path.exists(checkout_dir):
            _clone_git(repo_url, checkout_dir, commit_hash)
        if checkout_dir in sys.path:
            continue
        sys.path.insert(0, checkout_dir)
class ContrastiveModelFamily(enum.Enum):
    """Families of contrastive models supported by this module."""

    LIT = 'lit'
    SIGLIP = 'siglip'

    def paper(self):
        """Returns the arXiv URL of the paper describing this family."""
        # Keyed by value rather than by `self.LIT` / `self.SIGLIP`: reaching
        # one member through another is not reliable across Python versions.
        urls = {
            'lit': 'https://arxiv.org/abs/2111.07991',
            'siglip': 'https://arxiv.org/abs/2303.15343',
        }
        return urls[self.value]

    def __lt__(self, other):
        """Orders families by their string value (makes members sortable)."""
        if not isinstance(other, ContrastiveModelFamily):
            return NotImplemented
        return self.value < other.value
@dataclasses.dataclass(frozen=True)
class ContrastiveModelConfig:
    """Describes a `big_vision` contrastive model."""

    # Model family; selects tokenizer and model setup in `load_model()`.
    family: ContrastiveModelFamily
    # Image tower variant, e.g. 'B/16'.
    variant: str
    # Side length (in pixels) images are resized to.
    res: int
    # Text tower variant, e.g. 'B'.
    textvariant: str
    # Embedding dimension passed to the model as output dimension.
    embdim: int
    # Number of tokens every text is truncated/padded to.
    seqlen: int
    # Tokenizer name or path, see `load_tokenizer_sp()`/`load_tokenizer_bert()`.
    tokenizer: str
    # Vocabulary size passed to the text model.
    vocab_size: int
    # Checkpoint path handed to the model module's `load()`.
    ckpt: str
@dataclasses.dataclass(frozen=True)
class ContrastiveModel:
    """Wraps a `big_vision` contrastive model.

    Bundles the Flax module with the matching tokenizer and offers helpers for
    preprocessing inputs, computing embeddings and turning embeddings into
    probabilities.
    """

    config: ContrastiveModelConfig
    flax_module: nn.Module
    # Exactly one of the two tokenizers is set by `load_model()`.
    tokenizer_sp: sentencepiece.SentencePieceProcessor | None
    tokenizer_bert: transformers.BertTokenizer | None

    def embed_images(self, params, images):
        """Embeds preprocessed images.

        Args:
          params: Model parameters, as returned by `load_model()`.
          images: 4-dimensional array from `preprocess_images()`.

        Returns:
          Tuple `(zimg, out)` of image embeddings and the model's extra outputs.
        """
        assert getattr(images, 'ndim', None) == 4, 'Must call `.preprocess_images()`'
        zimg, _, out = self.flax_module.apply(dict(params=params), images, None)
        return zimg, out

    def embed_texts(self, params, texts):
        """Embeds preprocessed texts.

        Args:
          params: Model parameters, as returned by `load_model()`.
          texts: 2-dimensional token array from `preprocess_texts()`.

        Returns:
          Tuple `(ztxt, out)` of text embeddings and the model's extra outputs.
        """
        assert getattr(texts, 'ndim', None) == 2, 'Must call `.preprocess_texts()`'
        _, ztxt, out = self.flax_module.apply(dict(params=params), None, texts)
        return ztxt, out

    def preprocess_texts(self, texts):
        """Tokenizes texts and pads/truncates them to `config.seqlen` tokens.

        Args:
          texts: Iterable of strings.

        Returns:
          `np.ndarray` with one row of `config.seqlen` token ids per text.

        Raises:
          ValueError: If `config.family` is not a supported family.
        """
        family = self.config.family

        def tokenize_pad(text, seqlen=self.config.seqlen):
            if family == ContrastiveModelFamily.LIT:
                tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)[:-1]  # removes [SEP]
                tokens = tokens[:seqlen]
                return tokens + [0] * (seqlen - len(tokens))
            if family == ContrastiveModelFamily.SIGLIP:
                tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
                if len(tokens) >= seqlen:
                    # "sticky" eos: truncated texts still end with EOS.
                    return tokens[:seqlen - 1] + [self.tokenizer_sp.eos_id()]
                return tokens + [0] * (seqlen - len(tokens))
            raise ValueError(f'Unsupported model family: {family!r}')

        return np.array([tokenize_pad(text) for text in texts])

    def preprocess_images(self, images):
        """Resizes images to `config.res` and rescales their values.

        Args:
          images: A single image or a list/tuple of images. Every image is
            either a `PIL.Image.Image` or an array accepted by
            `PIL.Image.fromarray()`.

        Returns:
          `np.ndarray` stacking all resized images, computed as
          `pixels / 127.5 - 1.0` (i.e. [-1, 1] for 8-bit pixel values).
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        def topil(image):
            if not isinstance(image, PIL.Image.Image):
                image = PIL.Image.fromarray(image)
            return image

        return np.array([
            topil(image).resize([self.config.res, self.config.res])
            for image in images
        ]) / 127.5 - 1.0

    def get_bias(self, out):
        """Returns the scalar `out['b']`; only valid for SigLIP models."""
        assert self.config.family == ContrastiveModelFamily.SIGLIP, self.config.family
        return out['b'].item()

    def get_temperature(self, out):
        """Returns the scalar `out['t']`."""
        return out['t'].item()

    def get_probabilities(self, zimg, ztxt, temperature, *, axis=None, bias=None):
        """Converts image/text embeddings into probabilities.

        Args:
          zimg: Image embeddings.
          ztxt: Text embeddings.
          temperature: Factor the similarities `zimg @ ztxt.T` are scaled by.
          axis: LiT only (required): softmax axis, -1 normalizes over texts and
            -2 over images. Must be `None` for SigLIP.
          bias: SigLIP only (required): offset added before the sigmoid. Must
            be `None` for LiT.

        Returns:
          Array of probabilities with one entry per (image, text) pair.

        Raises:
          ValueError: If `config.family` is not a supported family.
        """
        # Note: zimg, ztxt are already normalized.
        if self.config.family == ContrastiveModelFamily.LIT:
            assert bias is None
            assert axis in (-1, -2), 'Must specify axis: -1/-2=normalize texts/images'
            return jax.nn.softmax(zimg @ ztxt.T * temperature, axis=axis)
        if self.config.family == ContrastiveModelFamily.SIGLIP:
            assert axis is None
            assert bias is not None, 'Must specify bias.'
            return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
        raise ValueError(f'Unsupported model family: {self.config.family!r}')
def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size):
    """Builds a `ContrastiveModelConfig` and derives its tokenizer.

    Args:
      family: Family value, 'lit' or 'siglip'.
      variant: Image tower variant.
      res: Image resolution.
      textvariant: Text tower variant.
      ckpt: Checkpoint path.
      embdim: Embedding dimension.
      seqlen: Text sequence length.
      vocab_size: Vocabulary size of the text model.

    Returns:
      The assembled `ContrastiveModelConfig`.
    """
    if family == 'lit':
        # LiT vocabularies are stored next to the checkpoint as a `.txt` file.
        tokenizer = ckpt.replace('.npz', '.txt')
    else:
        tokenizer = 'c4_en'
    return ContrastiveModelConfig(
        family=ContrastiveModelFamily(family),
        variant=variant,
        res=res,
        textvariant=textvariant,
        embdim=embdim,
        seqlen=seqlen,
        tokenizer=tokenizer,
        # Forward the argument instead of a hard-coded 32_000.
        vocab_size=vocab_size,
        ckpt=ckpt,
    )
# Available model configurations, keyed by a short name.
# Positional arguments of `_make_config()`:
#   family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size
# NOTE(review): `lit_b16s` and `lit_b16ti` use the 'L/16' variant and `LiT-L16*`
# checkpoints, and `siglip_so400m14so440m_224` spells "so440m" -- the key names
# look inconsistent; they are kept unchanged because callers may rely on them.
# NOTE(review): `load_model()` only maps LiT text variants 'B' and 'L', so the
# 'S' and 'Ti' entries would presumably fail there -- confirm.
MODEL_CONFIGS = {
    'lit_b16b': _make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
    'lit_l16l': _make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
    'lit_b16s': _make_config('lit', 'L/16', 224, 'S', 'gs://vit_models/lit/LiT-L16S.npz', 1024, 16, 32_000),
    'lit_b16ti': _make_config('lit', 'L/16', 224, 'Ti', 'gs://vit_models/lit/LiT-L16Ti.npz', 1024, 16, 32_000),
    'siglip_b16b_224': _make_config('siglip', 'B/16', 224, 'B', 'gs://big_vision/siglip/webli_en_b16_224_63724782.npz', 768, 64, 32_000),
    'siglip_b16b_256': _make_config('siglip', 'B/16', 256, 'B', 'gs://big_vision/siglip/webli_en_b16_256_60500360.npz', 768, 64, 32_000),
    'siglip_b16b_384': _make_config('siglip', 'B/16', 384, 'B', 'gs://big_vision/siglip/webli_en_b16_384_68578854.npz', 768, 64, 32_000),
    'siglip_b16b_512': _make_config('siglip', 'B/16', 512, 'B', 'gs://big_vision/siglip/webli_en_b16_512_68580893.npz', 768, 64, 32_000),
    'siglip_l16l_256': _make_config('siglip', 'L/16', 256, 'L', 'gs://big_vision/siglip/webli_en_l16_256_60552751.npz', 1024, 64, 32_000),
    'siglip_l16l_384': _make_config('siglip', 'L/16', 384, 'L', 'gs://big_vision/siglip/webli_en_l16_384_63634585.npz', 1024, 64, 32_000),
    'siglip_so400m14so440m_224': _make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
    'siglip_so400m14so400m_384': _make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
}
def load_tokenizer_sp(name_or_path):
    """Loads a SentencePiece tokenizer.

    Args:
      name_or_path: Either the known name 'c4_en' or a path to a serialized
        SentencePiece model readable through `gfile`.

    Returns:
      A `sentencepiece.SentencePieceProcessor` with the model loaded.
    """
    tok = sentencepiece.SentencePieceProcessor()
    path = {
        'c4_en': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
    }.get(name_or_path, name_or_path)
    # Context manager so the file handle is closed once the model is read.
    with gfile.GFile(path, 'rb') as f:
        tok.LoadFromSerializedProto(f.read())
    return tok
def load_tokenizer_bert(path):
    """Loads a lower-casing BERT tokenizer from a vocabulary file.

    Args:
      path: Path to the vocabulary file. `gs://` paths are first copied to a
        local temporary file.

    Returns:
      A `transformers.BertTokenizer` created with `do_lower_case=True`.
    """
    if not path.startswith('gs://'):
        return transformers.BertTokenizer(path, do_lower_case=True)
    # A private temporary directory avoids the race inherent to
    # `tempfile.mktemp()` and removes the local copy once the tokenizer has
    # been constructed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, 'vocab.txt')
        gfile.copy(path, local_path)
        return transformers.BertTokenizer(local_path, do_lower_case=True)
def load_model(config, check_params=False):
    """Loads `big_vision` model.

    Args:
      config: The `ContrastiveModelConfig` describing the model to load.
      check_params: If true, the model is initialized on zero-valued inputs
        and the resulting parameters are passed to the checkpoint loader;
        otherwise `None` is passed, which is faster but bypasses loading
        sanity-checks.

    Returns:
      Tuple `(params_cpu, model)` of the loaded parameters and a
      `ContrastiveModel` wrapping the Flax module and its tokenizer.
    """
    assert isinstance(config, ContrastiveModelConfig), type(config)

    is_lit = config.family == ContrastiveModelFamily.LIT
    tokenizer_sp = tokenizer_bert = None

    model_cfg = ml_collections.ConfigDict()
    model_cfg.image_model = 'vit'  # TODO(lbeyer): remove later, default
    if is_lit:
        model_cfg.text_model = 'proj.flaxformer.bert'
        model_cfg.image = dict(
            variant=config.variant, pool_type='tok', head_zeroinit=False)
        bert_variants = {'B': 'base', 'L': 'large'}
        model_cfg.text = dict(
            config=bert_variants[config.textvariant], head_zeroinit=False)
        tokenizer_bert = load_tokenizer_bert(config.tokenizer)
        # (image_out_dim, text_out_dim)
        image_out_dim = None if config.variant == 'L/16' else config.embdim
        model_cfg.out_dim = (image_out_dim, config.embdim)
    else:
        model_cfg.image = dict(variant=config.variant, pool_type='map')
        model_cfg.text_model = 'proj.image_text.text_transformer'  # TODO(lbeyer): remove later, default
        model_cfg.text = dict(
            variant=config.textvariant, vocab_size=config.vocab_size)
        model_cfg.bias_init = -10.0
        tokenizer_sp = load_tokenizer_sp(config.tokenizer)
        model_cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
    model_cfg.temperature_init = 10.0

    # Imported by name at call time rather than at module import time.
    two_towers = importlib.import_module(
        'big_vision.models.proj.image_text.two_towers')
    model = two_towers.Model(**model_cfg)

    if check_params:
        dummy_images = jnp.zeros([1, config.res, config.res, 3])
        dummy_texts = jnp.zeros([1, config.seqlen], jnp.int32)
        init_params = model.init(
            jax.random.PRNGKey(0), dummy_images, dummy_texts)['params']
    else:
        init_params = None  # Faster but bypasses loading sanity-checks.

    params_cpu = two_towers.load(init_params, config.ckpt, model_cfg)
    wrapper = ContrastiveModel(
        config=config,
        flax_module=model,
        tokenizer_sp=tokenizer_sp,
        tokenizer_bert=tokenizer_bert,
    )
    return params_cpu, wrapper