import os
import random
import unittest

import numpy as np
import torch
from torch.nn.functional import scaled_dot_product_attention

from finetrainers.models.attention_dispatch import (
    AttentionProvider,
    _AttentionProviderRegistry,
    _set_context_parallel_options,
    attention_dispatch,
    attention_provider,
    flash_attn_flash_attention,
    native_cudnn_attention,
    native_efficient_attention,
    native_flash_attention,
)
from finetrainers.parallel.ptd import _EquipartitionSharder


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return int(os.environ.get("WORLD_SIZE", 1))

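# Single-process provider parity tests: each attention provider is run on the same dummy
# inputs and compared against the first (native math) provider's output within a
# provider-specific tolerance.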
class AttentionDispatchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        set_seed(0)

    def test_forward(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        cuda_capability = torch.cuda.get_device_capability()

        query, key, value = self._create_dummy_inputs()

        all_providers = [
            (AttentionProvider._NATIVE_MATH, 0),
            (AttentionProvider.NATIVE, 5e-3),
            (AttentionProvider.FLASH, 5e-3),
            (AttentionProvider.FLASH_VARLEN, 5e-3),
            (AttentionProvider.FLEX, 2e-2),
            (AttentionProvider._NATIVE_CUDNN, 5e-3),
            (AttentionProvider._NATIVE_EFFICIENT, 5e-3),
            (AttentionProvider._NATIVE_FLASH, 5e-3),
            (AttentionProvider.SAGE, 1e-1),
            (AttentionProvider.SAGE_VARLEN, 2e-0),
            (AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, 2e-0),
            (AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, 2e-0),
            (AttentionProvider.XFORMERS, 5e-3),
        ]

        if cuda_capability >= (8, 9):
            all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, 2e-0))
        if cuda_capability >= (9, 0):
            all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA_SM90, 2e-0))

        ref_output = None
        for i, (provider, threshold) in enumerate(all_providers):
            try:
                output = self._check_forward_pass(provider, query, key, value)
                if i == 0:
                    ref_output = output.detach().clone()
                else:
                    self.assertTrue(
                        torch.allclose(output, ref_output, atol=threshold), f"Forward pass mismatch for {provider}"
                    )
            except Exception as e:
                print(f"Warning: Forward pass test failed for {provider} with error: {e}")

    def test_backward(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        query, key, value = self._create_dummy_inputs()

        selected_providers = [
            AttentionProvider.FLASH,
            AttentionProvider.FLASH_VARLEN,
            AttentionProvider.FLEX,
            AttentionProvider.NATIVE,
            AttentionProvider.XFORMERS,
        ]

        ref_output = None
        for i, provider in enumerate(selected_providers):
            try:
                output = self._check_backward_pass(provider, query, key, value)
                if i == 0:
                    ref_output = output.detach().clone()
                else:
                    threshold = 1e-2 if provider == AttentionProvider.FLEX else 1e-3
                    self.assertTrue(
                        torch.allclose(output, ref_output, atol=threshold), f"Backward pass mismatch for {provider}"
                    )
            except Exception as e:
                print(f"Warning: Backward pass test failed for {provider} with error: {e}")

    def _create_dummy_inputs(
        self, batch_size=2, num_heads=8, seq_len=256, head_dim=64, dtype=torch.bfloat16, device="cuda"
    ):
        torch.manual_seed(0)
        query = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        key = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        value = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        return query, key, value

    def _check_forward_pass(self, provider: AttentionProvider, query, key, value):
        kwargs = {}
        if provider == AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA:
            kwargs["pv_accum_dtype"] = "fp32"
        with attention_provider(provider):
            output = attention_dispatch(query, key, value, attention_kwargs=kwargs)
        self.assertIsNotNone(output)
        self.assertEqual(output.shape, query.shape)
        return output

    def _check_backward_pass(self, provider: AttentionProvider, query, key, value):
        query.requires_grad_(True)
        key.requires_grad_(True)
        value.requires_grad_(True)

        with attention_provider(provider):
            output = attention_dispatch(query, key, value)
            loss = output.mean()
            loss.backward()

        self.assertTrue(query.grad is not None)
        self.assertTrue(key.grad is not None)
        self.assertTrue(value.grad is not None)

        query.grad.zero_()
        key.grad.zero_()
        value.grad.zero_()
        return output

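# Multi-process ring/context-parallel attention tests, meant to be launched with a
# distributed launcher such as torchrun (NCCL backend). Every rank computes attention
# on its sequence shard and the unsharded result is compared against a full-sequence
# scaled_dot_product_attention reference computed with the math SDPA backend.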
class RingAttentionTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        torch.distributed.init_process_group(backend="nccl")
        rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()

        cls.rank = rank
        cls.world_size = world_size
        torch.cuda.set_device(rank)
        cls.mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,))

        set_seed(0)
        cls.batch_size = 2
        cls.num_heads = 8
        cls.seq_len = 256
        cls.head_dim = 64
        cls.dtype = torch.bfloat16
        cls.device = "cuda"

        _AttentionProviderRegistry._set_context_parallel(
            mesh=cls.mesh, convert_to_fp32=True, rotate_method="allgather"
        )
        _set_context_parallel_options(is_causal=False)

        cls.full_query = torch.randn(
            cls.batch_size,
            cls.num_heads,
            cls.seq_len * cls.world_size,
            cls.head_dim,
            dtype=cls.dtype,
            device=cls.device,
            requires_grad=True,
        )
        cls.full_key = torch.randn(
            cls.batch_size,
            cls.num_heads,
            cls.seq_len * cls.world_size,
            cls.head_dim,
            dtype=cls.dtype,
            device=cls.device,
            requires_grad=True,
        )
        cls.full_value = torch.randn(
            cls.batch_size,
            cls.num_heads,
            cls.seq_len * cls.world_size,
            cls.head_dim,
            dtype=cls.dtype,
            device=cls.device,
            requires_grad=True,
        )

        with torch.no_grad():
            torch.distributed.broadcast(cls.full_query, src=0)
            torch.distributed.broadcast(cls.full_key, src=0)
            torch.distributed.broadcast(cls.full_value, src=0)

        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            reference_output = scaled_dot_product_attention(cls.full_query, cls.full_key, cls.full_value)

        cls.reference_output = reference_output.detach().clone()
        reference_output.sum().backward()

        cls.query, cls.key, cls.value = (
            _EquipartitionSharder.shard(x, dim=2, mesh=cls.mesh).detach().clone()
            for x in (cls.full_query, cls.full_key, cls.full_value)
        )

    @classmethod
    def tearDownClass(cls):
        torch.distributed.destroy_process_group()

    def _test_forward_native_cudnn_attention(self, atol: float = 1e-3):
        output = native_cudnn_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_forward_native_efficient_attention(self, atol: float = 1e-3):
        output = native_efficient_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_forward_native_flash_attention(self, atol: float = 1e-3):
        output = native_flash_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_forward_flash_attn_flash_attention(self, atol: float = 1e-3):
        output = flash_attn_flash_attention(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))

    def _test_backward_native_cudnn_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = native_cudnn_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
            self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
            self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
            self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))

    def _test_backward_native_efficient_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = native_efficient_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
            self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
            self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
            self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))

    def _test_backward_native_flash_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = native_flash_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
            self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
            self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
            self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))

    def _test_backward_flash_attn_flash_attention(self, atol: float = 1e-3):
        query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
        query.requires_grad = True
        key.requires_grad = True
        value.requires_grad = True
        output = flash_attn_flash_attention(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            q_g, k_g, v_g = (
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
            self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
            self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
            self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))

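# Mixin that exposes the private `_test_*` helpers above as unittest test cases.
# Backward tolerances are scaled by the world size.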
class RingAttentionCPTesterMixin:
    def test_forward_native_cudnn_attention(self):
        self._test_forward_native_cudnn_attention(atol=1e-2)

    def test_forward_native_efficient_attention(self):
        self._test_forward_native_efficient_attention(atol=1e-2)

    def test_forward_native_flash_attention(self):
        self._test_forward_native_flash_attention(atol=1e-2)

    def test_forward_flash_attn_flash_attention(self):
        self._test_forward_flash_attn_flash_attention(atol=1e-2)

    def test_backward_native_cudnn_attention(self):
        atol = 1e-2 * self.world_size
        self._test_backward_native_cudnn_attention(atol=atol)

    def test_backward_native_efficient_attention(self):
        atol = 1e-2 * self.world_size
        self._test_backward_native_efficient_attention(atol=atol)

    def test_backward_native_flash_attention(self):
        atol = 1e-2 * self.world_size
        self._test_backward_native_flash_attention(atol=atol)

    @unittest.skip(
        "query diff: 0.298828125, key diff: 2.09375, value diff: 0.68359375; needs further investigation"
    )
    def test_backward_flash_attn_flash_attention(self):
        atol = 1.5e-1 * self.world_size
        self._test_backward_flash_attn_flash_attention(atol=atol)


@unittest.skipIf(
    not torch.cuda.is_available() or get_world_size() != 2, "CUDA is not available or world size is not 2"
)
class RingAttentionCP2Test(RingAttentionTest, RingAttentionCPTesterMixin):
    pass


@unittest.skipIf(
    not torch.cuda.is_available() or get_world_size() != 4, "CUDA is not available or world size is not 4"
)
class RingAttentionCP4Test(RingAttentionTest, RingAttentionCPTesterMixin):
    pass