import torch

# Fixed RNG seed. Presumably passed to torch.manual_seed (or similar) by the
# consuming tests so randomly generated inputs are reproducible -- TODO confirm
# against the caller.
SEED = 1234

# Parameter sets, each a 4-tuple of ints (every value here is a multiple of 16).
# The meaning of each position is defined by whatever consumes this list
# (e.g. a pytest parametrize over problem sizes).
# NOTE(review): document the field order here once confirmed against the caller.
PARAMS = [
    (16, 16, 64, 512),
    (16, 32, 96, 768),
    (16, 64, 128, 1024),
    (32, 16, 128, 512),
    (32, 32, 160, 768),
    (32, 64, 192, 1024),
    (48, 32, 176, 1024),
    (48, 64, 224, 1280),
    (64, 16, 192, 768),
    (64, 32, 224, 1024),
    (64, 128, 256, 2048),
    (80, 32, 240, 1280),
    (80, 64, 256, 1536),
    (96, 32, 256, 1536),
    (96, 64, 288, 2048),
    (96, 128, 320, 3072),
    (112, 64, 320, 2048),
    (112, 128, 352, 2560),
    (128, 32, 256, 1024),
    (128, 64, 320, 1536),
    (128, 128, 384, 3072),
    (160, 64, 320, 1536),
    (160, 128, 384, 2560),
    (192, 64, 384, 2048),
    (192, 128, 448, 3072),
    (192, 256, 512, 4096),
]

# Absolute/relative comparison tolerances per dtype. The "atol"/"rtol" keys
# match the keyword names of torch.allclose / torch.testing.assert_close, so an
# entry can be splatted as ``**DTYPE_TO_TOLS[dtype]``. The looser bfloat16
# bounds are consistent with bfloat16 keeping fewer mantissa bits than
# float16 (7 vs 10).
DTYPE_TO_TOLS = {
    torch.float32: {"atol": 1e-4, "rtol": 1e-3},
    torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}

# Dtypes under test, derived from DTYPE_TO_TOLS so every listed dtype is
# guaranteed to have tolerances. Order follows dict insertion order:
# float32, bfloat16, float16.
DTYPES = list(DTYPE_TO_TOLS)