Spaces:
Runtime error
Runtime error
| # ------------------------------------------------------------------------ | |
| # Copyright (c) 2023-present, BAAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ------------------------------------------------------------------------ | |
| """Prompt encoder.""" | |
| import torch | |
| from torch import nn | |
| class PromptEncoder(nn.Module): | |
| """Module to encode geometric prompts.""" | |
| def __init__(self, embed_dim, image_size): | |
| super(PromptEncoder, self).__init__() | |
| self.img_size = [image_size] * 2 | |
| self.point_embed = nn.Embedding(5, embed_dim) # [bg, fg, lt, rb, pad] | |
| self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64) | |
| self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2))) | |
| self.img_pos = None | |
| def to_tensor(self, input): | |
| """Convert input to tensor.""" | |
| if input is None: | |
| return input | |
| if not isinstance(input, torch.Tensor): | |
| input = torch.from_numpy(input) | |
| if input.device != self.coord_matrix.device: | |
| input = input.to(device=self.coord_matrix.device) | |
| return input | |
| def to_points(self, points=None, boxes=None): | |
| """Convert points or boxes to point prompts.""" | |
| if points is not None: | |
| if isinstance(points, (tuple, list)): | |
| coords, labels = points | |
| else: | |
| coords, labels = points[:, :, :2], points[:, :, 2] | |
| coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1]) | |
| coords = self.to_tensor(coords.clip(0, 1).astype("float32")) | |
| labels = self.to_tensor(labels.astype("int64")) | |
| return coords, labels | |
| if boxes is not None: | |
| coords = boxes.reshape((-1, 2, 2)) | |
| coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1]) | |
| coords = self.to_tensor(coords.clip(0, 1).astype("float32")) | |
| labels = self.to_tensor(self.corner_labels) | |
| return coords, labels | |
| return None | |
| def encode_coords(self, coords): | |
| """Return the embedding for given coords.""" | |
| pi4, pi2 = 4 * 3.1415926, 2 * 3.1415926 | |
| if self.coord_matrix.dtype != torch.float32: | |
| self.coord_matrix = self.coord_matrix.float() | |
| rad = coords.mul(pi4).sub_(pi2) @ self.coord_matrix | |
| dtype = self.point_embed.weight.dtype | |
| return torch.cat([rad.sin(), rad.cos()], dim=-1).to(dtype=dtype) | |
| def encode_points(self, coords, labels): | |
| """Return the embedding for given points.""" | |
| embed = self.encode_coords(coords) | |
| embed.mul_(labels.ne(4).unsqueeze_(-1).float().to(dtype=embed.dtype)) | |
| return embed.add_(self.point_embed(labels)) | |
| def encode_grid(self, grid_size): | |
| """Return the embedding for a grid of specified size.""" | |
| grid = torch.ones(*grid_size, dtype=torch.float32) | |
| y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0]) | |
| x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1]) | |
| coords = self.to_tensor(torch.stack([x, y], dim=-1)) | |
| return self.encode_coords(coords) | |
| def forward(self, inputs): | |
| sparse_embeds = [] | |
| if inputs.get("boxes", None) is not None: | |
| coords, labels = self.to_points(boxes=inputs["boxes"]) | |
| sparse_embeds.append(self.encode_points(coords, labels)) | |
| if inputs.get("points", None) is not None: | |
| coords, labels = self.to_points(points=inputs["points"]) | |
| sparse_embeds.append(self.encode_points(coords, labels)) | |
| if len(sparse_embeds) > 1: | |
| sparse_embeds = [torch.cat(sparse_embeds, dim=1)] | |
| elif len(sparse_embeds) == 0: | |
| raise ValueError("Excepted ``points`` or ``boxes`` prompts.") | |
| img_embed_size = torch.Size(inputs["img_embeds"].shape[2:-1]) | |
| if self.img_pos is None or self.img_pos.shape[0] != img_embed_size.numel(): | |
| self.img_pos = self.encode_grid(img_embed_size).flatten(0, 1) | |
| return {"sparse_embeds": sparse_embeds[0], "img_pos": self.img_pos} | |