RoyYang0714's picture
feat: Try to build everything locally.
9b33fca
"""ToTensor transformation."""
import numpy as np
import torch
from vis4d.data.const import CommonKeys as K
from vis4d.data.typing import DictData
from .base import Transform
def _array_to_tensor(array: np.ndarray) -> torch.Tensor:
    """Convert a numpy array to a tensor, copying only when required.

    ``torch.from_numpy`` raises on arrays with negative strides (for example
    views created by reversed slicing such as ``arr[..., ::-1]``), so those
    arrays are first copied into a C-contiguous buffer. Every other array is
    handed to ``torch.from_numpy`` unchanged.

    Args:
        array: The numpy array to convert.

    Returns:
        The tensor created from ``array``.
    """
    if any(stride < 0 for stride in array.strides):
        array = np.ascontiguousarray(array)
    return torch.from_numpy(array)


def _replace_arrays(data: DictData) -> None:
    """Replace numpy arrays with tensors, modifying ``data`` in place.

    Handling per entry:
      * ``K.images`` / ``K.original_images``: axes are permuted with
        ``(0, 3, 1, 2)`` (NHWC -> NCHW) and the result is stored as a
        contiguous tensor.
      * Any other numpy array: converted to a tensor.
      * Nested dict: processed recursively.
      * List: entries that are numpy arrays are converted; all other entries
        are left untouched.
      * Anything else is left untouched.

    Args:
        data: Dictionary whose values are converted in place.
    """
    # Built once instead of on every loop iteration.
    image_keys = (K.images, K.original_images)

    # Only existing keys are reassigned, so the dict size never changes while
    # iterating.
    for key, value in data.items():
        if key in image_keys:
            # Materialize the permuted layout in numpy first: this yields a
            # contiguous buffer regardless of the input's memory layout
            # (including non-contiguous or negatively strided views).
            data[key] = torch.from_numpy(
                np.ascontiguousarray(value.transpose(0, 3, 1, 2))
            )
        elif isinstance(value, np.ndarray):
            data[key] = _array_to_tensor(value)
        elif isinstance(value, dict):
            _replace_arrays(value)
        elif isinstance(value, list):
            for i, entry in enumerate(value):
                if isinstance(entry, np.ndarray):
                    value[i] = _array_to_tensor(entry)
@Transform("data", "data")
class ToTensor:
    """Convert the numpy entries of every DictData in a batch to torch tensors.

    The conversion is delegated to ``_replace_arrays`` and happens in place.
    Note that K.images is reshaped from NHWC to NCHW along the way.
    """

    def __call__(self, batch: list[DictData]) -> list[DictData]:
        """Convert each sample of the batch in place and return the batch."""
        for sample in batch:
            _replace_arrays(sample)
        return batch