Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License. | |
| import torch | |
| import torch.nn as nn | |
class BufferList(nn.Module):
    """Ordered container that registers tensors as non-persistent buffers.

    Each tensor is registered under its position as the buffer name
    (``"0"``, ``"1"``, ...).  Because they are buffers, they move with the
    module on ``.to()`` / ``.cuda()``.  Because ``persistent=False``, they are
    left out of ``state_dict()``.
    """

    def __init__(self, buffers):
        super().__init__()
        for index, tensor in enumerate(buffers):
            name = str(index)
            self.register_buffer(name, tensor, persistent=False)

    def __len__(self):
        # Number of registered buffers.
        return len(self._buffers)

    def __iter__(self):
        # Buffers are produced in registration order, i.e. the order of the
        # ``buffers`` iterable given to ``__init__``.
        yield from self._buffers.values()
class PointGenerator(nn.Module):
    """Precompute and serve per-level temporal points for a feature pyramid.

    Level ``i`` has stride ``s = strides[i]``. For it, a table is cached with
    one row per position ``0, s, 2s, ... < buffer_size`` and four columns::

        [position, reg_min, reg_max, stride]

    ``position`` is shifted by ``s / 2`` when ``offset`` is True.

    ``(reg_min, reg_max)`` is the level's regression range:

    * ``(0, strides[1])`` for the first level;
    * ``(strides[i], strides[i + 1])`` for middle levels;
    * ``(strides[-1], inf)`` for the last level (``(0, inf)`` if there is only
      one level).

    Args:
        strides: Per-level strides. ``forward`` requires ``strides[0] == 1``.
        buffer_size: Exclusive upper bound on cached positions. Level ``i``
            caches ``ceil(buffer_size / strides[i])`` rows.
        offset: If True, shift every position by half of its level's stride.
    """

    def __init__(self, strides, buffer_size, offset=False):
        super(PointGenerator, self).__init__()

        # Level i covers [strides[i], strides[i + 1]). The first level starts
        # at 0 and the last level is unbounded above.
        reg_range, last = [], 0
        for stride in strides[1:]:
            reg_range.append((last, stride))
            last = stride
        reg_range.append((last, float('inf')))

        self.strides = strides
        self.reg_range = reg_range
        self.buffer_size = buffer_size
        self.offset = offset
        self.buffer = self._cache_points()

    def _cache_points(self):
        """Build one ``(num_points, 4)`` table per level, wrapped in a BufferList."""
        # The legacy ``torch.Tensor(...)`` constructor uses the default dtype,
        # so building every column in it keeps the output dtype unchanged.
        dtype = torch.get_default_dtype()
        buffer_list = []
        for stride, reg_range in zip(self.strides, self.reg_range):
            # Positions must be floating point: with integer ``stride`` and
            # ``buffer_size`` an int64 ``arange`` rejects the in-place
            # ``+= 0.5 * stride`` below (RuntimeError when offset=True).
            points = torch.arange(0, self.buffer_size, stride, dtype=dtype)[:, None]
            if self.offset:
                points += 0.5 * stride
            num_points = points.size(0)
            reg = torch.tensor([reg_range], dtype=dtype).expand(num_points, -1)
            lv_stride = torch.tensor([[stride]], dtype=dtype).expand(num_points, -1)
            buffer_list.append(torch.cat((points, reg, lv_stride), dim=1))
        return BufferList(buffer_list)

    def forward(self, pymid):
        """Return the cached points for every pyramid level, concatenated.

        Args:
            pymid: Sequence of per-level feature tensors. ``size(1)`` of each
                one is taken as that level's number of positions. Presumably
                the tensors are ``(batch, length, ...)``; confirm with callers.

        Returns:
            Tensor of shape ``(sum of level lengths, 4)`` whose rows are
            ``[position, reg_min, reg_max, stride]``. Zero-length levels
            contribute nothing. Levels beyond ``len(self.strides)`` are
            ignored.

        Raises:
            AssertionError: if ``strides[0] != 1`` or a level is longer than
                its cached table (``buffer_size`` too small).
            RuntimeError: from ``torch.cat`` when no level has any positions.
        """
        # The first pyramid level is required to have stride 1.
        assert self.strides[0] == 1
        points = []
        for feat, buffer in zip(pymid, self.buffer):
            size = feat.size(1)
            if size == 0:
                continue
            assert size <= buffer.size(0), 'reached max buffer size'
            # torch.cat allocates a fresh tensor, so the cached buffer is never
            # handed to the caller and no explicit clone is needed.
            points.append(buffer[:size])
        return torch.cat(points)