diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py
index 4e27d04..28df9aa 100644
--- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py
+++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -2,11 +2,44 @@
 
 import torch
 import torch_npu
-from einops import rearrange
+from einops import rearrange, repeat
 
 __all__ = ["ApplyRotaryEmb"]
 
 
+def rotate_half(x, interleaved=False):
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
 class ApplyRotaryEmb(torch.autograd.Function):
     """
@@ -38,38 +71,59 @@ def forward(
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
 
-        re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
-        re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
-
-        cat_cos = torch.cat([re_cos, re_cos], -1)
-        cat_sin = torch.cat([re_sin, re_sin], -1)
-
-        rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin)
-        ctx.save_for_backward(cat_cos, cat_sin)
+        if interleaved:
+            cos = cos[:seqlen]
+            sin = sin[:seqlen]
+        else:
+            # "s d -> 1 s 1 d"
+            cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
+            sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2)
+        ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
-        if in_place:
-            x[..., :rotary_dim].copy_(rot)
-            return x
+        if interleaved:
+            out = apply_rotary_emb_torch(x, cos, sin, interleaved)
+            if in_place:
+                x.copy_(out)
+                return x
+            else:
+                return out
         else:
-            out = x.detach().clone()
-            if rotary_dim < head_dim and not in_place:
+            x_ro = x[..., :rotary_dim]
+            out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
+            if in_place:
+                x[..., :rotary_dim].copy_(out_ro)
+                return x
+            if rotary_dim < head_dim:
+                out = torch.empty_like(x)
+                out[..., :rotary_dim].copy_(out_ro)
                 out[..., rotary_dim:].copy_(x[..., rotary_dim:])
-            return out
+                return out
+            return out_ro
 
     @staticmethod
-    def backward(ctx, do):
-        cat_cos, cat_sin = ctx.saved_tensors
-        *_, seqlen, _, head_dim = do.shape
-        rotary_dim = cat_cos.shape[-1]
-
-        dx_out = torch_npu.npu_rotary_mul(
-            do[..., :rotary_dim], cat_cos, torch.neg(cat_sin)
-        )
-        if ctx.in_place:
-            do[..., :rotary_dim].copy_(dx_out)
-            return do, None, None, None, None
+    def backward(ctx, grad_out):
+        cos, sin = ctx.saved_tensors
+        rotary_dim = cos.shape[-1]
+        head_dim = grad_out.shape[-1]
+        if ctx.interleaved:
+            grad_input = apply_rotary_emb_torch(
+                grad_out, cos, torch.neg(sin), ctx.interleaved
+            )
+            if ctx.in_place:
+                grad_out.copy_(grad_input)
+                return grad_out, None, None, None, None
+            else:
+                return grad_input, None, None, None, None
         else:
-            dx = do.detach().clone()
-            dx[..., :rotary_dim].copy_(dx_out)
-            return dx, None, None, None, None
+            grad_out_ro = grad_out[..., :rotary_dim]
+            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
+            if ctx.in_place:
+                grad_out[..., :rotary_dim].copy_(grad_input_ro)
+                return grad_out, None, None, None, None
+            if rotary_dim < head_dim:
+                grad_input = torch.empty_like(grad_out)
+                grad_input[..., :rotary_dim].copy_(grad_input_ro)
+                grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:])
+                return grad_input, None, None, None, None
+            return grad_input_ro, None, None, None, None
diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py
index 1a2a36d..7764b9b 100644
--- a/deeplink_ext/internevo_ops/rotary_embedding.py
+++ b/deeplink_ext/internevo_ops/rotary_embedding.py
@@ -4,8 +4,7 @@
 
 platform_type = deeplink_ext_get_platform_type()
 if platform_type == PlatformType.TORCH_NPU:
-    # from ._rotary_embedding_npu import ApplyRotaryEmb
-    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
+    from ._rotary_embedding_npu import ApplyRotaryEmb
 elif platform_type == PlatformType.TORCH_DIPU:
     from ._rotary_embedding_dipu import ApplyRotaryEmb
 else:
diff --git a/tests/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py
index 981c2f0..722577e 100644
--- a/tests/internevo/test_rotary_embedding.py
+++ b/tests/internevo/test_rotary_embedding.py
@@ -8,40 +8,41 @@
 
 def test_ApplyRotaryEmb():
     input_dtype_list = [torch.float16, torch.bfloat16]
-    interleaved = False
     in_place_options = [False, True]
+    interleaved_options = [False, True]
     for input_dtype in input_dtype_list:
         for in_place in in_place_options:
-            input_ref = torch.randn(
-                1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
-            )
-            input_ext = input_ref.clone().detach().requires_grad_()
-            cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
-            sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
-
-            output_ref, grad_ref = call_autograd_func(
-                ApplyRotaryEmbTorch,
-                "cuda",
-                input_dtype,
-                input_ref,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            output_ext, grad_ext = call_autograd_func(
-                ApplyRotaryEmb,
-                "cuda",
-                input_dtype,
-                input_ext,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            assert allclose(
-                output_ref, output_ext, rtol=1e-2, atol=5e-2
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
-            assert allclose(
-                grad_ref, grad_ext
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
+            for interleaved in interleaved_options:
+                input_ref = torch.randn(
+                    1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
+                )
+                input_ext = input_ref.clone().detach().requires_grad_()
+                cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+                sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+
+                output_ref, grad_ref = call_autograd_func(
+                    ApplyRotaryEmbTorch,
+                    "cuda",
+                    input_dtype,
+                    input_ref,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                output_ext, grad_ext = call_autograd_func(
+                    ApplyRotaryEmb,
+                    "cuda",
+                    input_dtype,
+                    input_ext,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                assert allclose(
+                    output_ref, output_ext, rtol=1e-2, atol=5e-2
+                ), f"When input dtype is {input_dtype}, in_place is {in_place} and interleaved is {interleaved}, ApplyRotaryEmb fails to pass the forward test!"
+                assert allclose(
+                    grad_ref, grad_ext
+                ), f"When input dtype is {input_dtype}, in_place is {in_place} and interleaved is {interleaved}, ApplyRotaryEmb fails to pass the backward test!"
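
Usage note (not part of the patch): a minimal sketch of how the updated NPU path might be driven end to end, mirroring the shapes used in the test above. The device string "npu" and the tensor sizes are illustrative assumptions, and the positional argument order (x, cos, sin, interleaved, in_place) is inferred from the test's call_autograd_func invocation rather than stated in the diff.

    import torch
    import torch_npu  # registers the "npu" device with PyTorch
    from deeplink_ext.internevo_ops.rotary_embedding import ApplyRotaryEmb

    # x: (batch_size, seqlen, nheads, headdim); cos/sin: (rotary_seqlen, rotary_dim // 2).
    # Here rotary_dim == headdim == 64, so the whole head is rotated.
    x = torch.randn(1, 64, 32, 64, dtype=torch.float16, device="npu", requires_grad=True)
    cos = torch.randn(64, 32, dtype=torch.float16, device="npu")
    sin = torch.randn(64, 32, dtype=torch.float16, device="npu")

    # interleaved=False takes the torch_npu.npu_rotary_mul fast path;
    # interleaved=True falls back to the pure-torch apply_rotary_emb_torch.
    out = ApplyRotaryEmb.apply(x, cos, sin, False, False)  # interleaved=False, in_place=False
    out.sum().backward()  # exercises the custom backward as well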