import torch

# Fixed integer seed constant; how it is applied is not visible in this
# module (presumably passed to an RNG seeding call by the consumer -- confirm).
SEED = 1234

# Parameter grid: 26 four-integer tuples, one configuration per row.
# NOTE(review): the meaning of each column is not stated here (they look like
# size/dimension values for parametrized cases) -- confirm against the code
# that unpacks these tuples before relying on any particular interpretation.
PARAMS = [
    (16, 16, 64, 512),
    (16, 32, 96, 768),
    (16, 64, 128, 1024),
    (32, 16, 128, 512),
    (32, 32, 160, 768),
    (32, 64, 192, 1024),
    (48, 32, 176, 1024),
    (48, 64, 224, 1280),
    (64, 16, 192, 768),
    (64, 32, 224, 1024),
    (64, 128, 256, 2048),
    (80, 32, 240, 1280),
    (80, 64, 256, 1536),
    (96, 32, 256, 1536),
    (96, 64, 288, 2048),
    (96, 128, 320, 3072),
    (112, 64, 320, 2048),
    (112, 128, 352, 2560),
    (128, 32, 256, 1024),
    (128, 64, 320, 1536),
    (128, 128, 384, 3072),
    (160, 64, 320, 1536),
    (160, 128, 384, 2560),
    (192, 64, 384, 2048),
    (192, 128, 448, 3072),
    (192, 256, 512, 4096),
]

# Per-dtype comparison tolerances, stored as {"atol": ..., "rtol": ...}.
# The key names match the `atol`/`rtol` keyword arguments of
# torch.testing.assert_close / torch.allclose, so each inner dict is
# presumably splatted into such a call -- confirm at the call site.
DTYPE_TO_TOLS = {
    torch.float32: {"atol": 1e-4, "rtol": 1e-3},
    torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}

# Every dtype that has a tolerance entry, in the mapping's insertion order
# (float32, bfloat16, float16). Iterating a dict yields its keys, so no
# explicit .keys() call is needed.
DTYPES = list(DTYPE_TO_TOLS)