from typing import List, Tuple

import torch
from einops import rearrange
from PIL import Image
from torch.nn import functional as F
from torchvision.transforms.v2 import InterpolationMode
from torchvision.transforms.v2.functional import normalize
from torchvision.transforms.v2.functional import resize as tv_resize
from torchvision.transforms.v2.functional import to_dtype, to_image

from .layers import attn, layer_norm, linear, mlp
from .weights import VisionModel, load_from_safetensors


def im_resize(
    image: Image.Image,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BICUBIC,
) -> Image.Image:
    """
    Thin wrapper around torchvision's `resize`, which has an overly narrow
    type signature: it accepts both PIL images and torch tensors, but the
    annotations only allow tensors, so we silence the type checker here.
    """
    return tv_resize(
        image,  # type: ignore
        size,
        interpolation,
    )


def create_patches(
    image: Image.Image, image_patch_size: int = 378
) -> Tuple[List[Image.Image], Tuple[int, int]]:
    """
    Split the given image into a variable number of patches depending upon its
    resolution.
    """
    # Start off with the global patch.
    patches = [im_resize(image, [image_patch_size, image_patch_size])]

    # Find the closest resolution template.
    res_templates = [(1, 2), (2, 1), (2, 2)]
    im_width, im_height = image.size
    max_dim = max(im_width, im_height)
    if max_dim < image_patch_size * 1.4:
        # If the image is already small, we just do a single patch that is a
        # duplicate of the global patch. This creates a small amount of
        # redundant computation now, but it is simpler and future-proofs us
        # if/when we condition the vision encoder on the patch type.
        res_template = (1, 1)
        patches.append(patches[0])
    else:
        aspect_ratio = im_width / im_height
        res_template = min(
            res_templates, key=lambda size: abs((size[1] / size[0]) - aspect_ratio)
        )
        # TODO: Actually implement patching... just going to put in the global
        # patch for now to make progress on other aspects.
        patches.append(patches[0])

    return patches, res_template
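
# Illustrative note (a sketch, not part of the original logic): with the default
# 378-pixel patch size, a 1000x600 image has max(1000, 600) >= 378 * 1.4, so
# template selection runs. Its aspect ratio is 1000 / 600 ~= 1.67, and the
# closest template ratio among (1, 2), (2, 1), (2, 2) is 2 / 1 = 2.0, so the
# returned res_template is (1, 2) -- one row of two columns, reading the
# template as (rows, columns) since its ratio is compared against width / height.
# Until the TODO above is implemented, the patch list itself is still just the
# global patch repeated twice:
#
#     img = Image.new("RGB", (1000, 600))
#     _, template = create_patches(img)
#     assert template == (1, 2)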


def encode_image(image: Image.Image, weights: VisionModel) -> torch.Tensor:
    patches, res_template = create_patches(image.convert("RGB"))

    # Convert each PIL patch to a float16 CHW tensor in [0, 1], then normalize
    # to [-1, 1] and stack into a (num_patches, 3, H, W) batch.
    patches = torch.stack(
        [
            normalize(
                to_dtype(to_image(patch), torch.float16, scale=True),
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
            )
            for patch in patches
        ]
    )

    outputs = vision_encoder(patches, weights)

    # TODO: Merge sub-image patch outputs properly... for now we'll just assume
    # that the global patch is repeated.
    assert outputs.shape[0] == 2, "Expected the global patch plus one duplicate."
    outputs = torch.cat([outputs[0], outputs[1]], dim=-1)

    return mlp(outputs, weights.proj_mlp)


def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel) -> torch.Tensor:
    # Patchify: B3HW -> B(HxW)(3xP1xP2), aka BTC.
    x = rearrange(
        input_BCHW,
        "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
        p1=w.patch_size,
        p2=w.patch_size,
    )
    x = linear(x, w.patch_emb)
    x = x + w.pos_emb

    # Pre-norm transformer blocks with residual connections.
    for block in w.blocks:
        x = x + attn(layer_norm(x, block.ln1), block.attn)
        x = x + mlp(layer_norm(x, block.ln2), block.mlp)

    x = layer_norm(x, w.post_ln)
    return x
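

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # load_from_safetensors is imported above, but its exact signature lives in
    # weights.py; here we assume it takes a safetensors path and returns a
    # VisionModel. Adjust to match the real API.
    import sys

    vision_weights = load_from_safetensors(sys.argv[1])  # assumed: path -> VisionModel
    img = Image.open(sys.argv[2])
    embedding = encode_image(img, vision_weights)
    print(embedding.shape)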