import torch
def zero_module(module: torch.nn.Module) -> torch.nn.Module:
    """Zero out the parameters of a module and return it.

    Every tensor yielded by ``module.parameters()`` is overwritten with
    zeros in place; the same module object is returned so the call can be
    used inline when constructing a layer.

    Args:
        module: Module whose parameters are zeroed in place.

    Returns:
        The same ``module`` instance that was passed in.
    """
    for param in module.parameters():
        # detach() yields a tensor sharing the parameter's storage but cut
        # off from autograd, so the in-place zero_() is not recorded.
        param.detach().zero_()
    return module
class StackedRandomGenerator:
    """Draw random tensors using one independent generator per batch item.

    One ``torch.Generator`` is created for each entry of ``seeds``. Every
    sampling method draws the i-th slice along dimension 0 from the i-th
    generator and stacks the slices, so a given slice depends only on its
    own seed and not on which other seeds share the batch.
    """

    def __init__(self, device, seeds):
        """Create one seeded generator per entry of ``seeds``.

        Args:
            device: Device passed to ``torch.Generator``.
            seeds: Iterable of values convertible with ``int()``; each is
                reduced modulo 2**32 before seeding its generator.
        """
        super().__init__()
        self.generators = [
            torch.Generator(device).manual_seed(int(seed) % (1 << 32))
            for seed in seeds
        ]

    def _check_batch(self, size):
        """Assert that the leading entry of ``size`` equals the generator count."""
        expected = len(self.generators)
        assert size[0] == expected, (
            f"size[0]={size[0]} does not match the number of generators ({expected})"
        )

    def randn(self, size, **kwargs):
        """Return stacked ``torch.randn`` samples, one slice per generator.

        Args:
            size: Full output shape; ``size[0]`` must equal the number of
                generators and ``size[1:]`` is the shape of each slice.
            **kwargs: Forwarded unchanged to ``torch.randn``.

        Returns:
            Tensor built by stacking the per-generator samples along dim 0.

        Raises:
            AssertionError: If ``size[0]`` differs from the generator count.
        """
        self._check_batch(size)
        item_shape = size[1:]
        samples = [
            torch.randn(item_shape, generator=gen, **kwargs)
            for gen in self.generators
        ]
        return torch.stack(samples)

    def randn_like(self, input):
        """Return ``randn`` samples with the shape, dtype, layout and device of ``input``."""
        return self.randn(
            input.shape,
            dtype=input.dtype,
            layout=input.layout,
            device=input.device,
        )

    def randint(self, *args, size, **kwargs):
        """Return stacked ``torch.randint`` samples, one slice per generator.

        Args:
            *args: Positional arguments forwarded unchanged to
                ``torch.randint`` ahead of ``size``.
            size: Full output shape (keyword-only); ``size[0]`` must equal
                the number of generators and ``size[1:]`` is the shape of
                each slice.
            **kwargs: Forwarded unchanged to ``torch.randint``.

        Returns:
            Tensor built by stacking the per-generator samples along dim 0.

        Raises:
            AssertionError: If ``size[0]`` differs from the generator count.
        """
        self._check_batch(size)
        item_shape = size[1:]
        samples = [
            torch.randint(*args, size=item_shape, generator=gen, **kwargs)
            for gen in self.generators
        ]
        return torch.stack(samples)