From 23f7a965cbbcee593c061091ab2edf55486dfc76 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Tue, 26 Mar 2024 17:05:10 +0800 Subject: [PATCH 01/39] update cpp ext of rms norm --- csrc/extensions.cpp | 62 +++++++++---------- .../internlm_ops/rms_norm/deeplink.py | 48 ++++++++------ tests/test_rms_lightlm.py | 36 ++++++++--- 3 files changed, 85 insertions(+), 61 deletions(-) diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index 8b4c6389..d3b32ada 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -27,50 +27,46 @@ namespace dipu::dipu_ext { -namespace { -at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault( - const OptionalIntArray& opt, at::IntArrayRef def) { - if (opt) { - return {*opt}; - } - return def; -} - -} // namespace - -auto extRmsNorm(const at::Tensor& input, +auto extRmsNorm(at::Tensor& output, + at::Tensor& inv_rms, + const at::Tensor& input, const OptionalIntArray& normalized_shape, - const at::Tensor& weight, const at::Tensor& bias, double eps) { - at::OptionalIntArrayRef normalized_shape_at = - optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes()); - auto input_shape = input.sizes(); - std::vector input_size(input_shape.begin(), input_shape.end()); - input_size.back() = 1; - auto inv_rms = at::empty(input_size, input.options()); - auto output = at::empty_like(input); - callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, - bias, eps); + const at::Tensor& weight, + const at::Tensor& bias, + double eps) { + at::OptionalIntArrayRef normalized_shape_at = *normalized_shape; + callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, bias, eps); return std::make_tuple(std::move(output), std::move(inv_rms)); } -auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output, + +auto extRmsNormBackward(at::Tensor& grad_input, + at::Tensor& grad_weight, + at::Tensor& grad_bias, + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias, const at::Tensor& inv_rms, const OptionalIntArray& normalized_shape, - const at::Tensor& weight, const at::Tensor& bias, double eps) { - at::OptionalIntArrayRef normalized_shape_at = - optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes()); - auto grad_input = at::empty_like(grad_output); - auto grad_weight = at::empty_like(weight); - auto grad_bias = at::empty_like(bias); - callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias, - grad_output, input, weight, bias, inv_rms, normalized_shape_at, + at::OptionalIntArrayRef normalized_shape_at = *normalized_shape; + callDiopi(diopiRMSNormBackward, + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + normalized_shape_at, eps); - return std::make_tuple(std::move(grad_input), std::move(grad_weight), - std::move(grad_bias)); + return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); } + void extApplyRotary(at::Tensor output, const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, const bool conj, const bool interleaved) { diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index 406dc4aa..f9342fc0 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -10,7 +10,10 @@ class _DeepLinkRMSNormFunction(torch.autograd.Function): @staticmethod def forward(ctx, hidden_states, weight, bias, eps): - output, inv_rms = 
ext.rms_norm(hidden_states, None, weight, bias, eps)
+        output = torch.empty_like(hidden_states)
+        inv_rms_shape = list(hidden_states.shape[:-1], 1)
+        inv_rms = torch.empty(inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device)
+        ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps)
 
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
         return output
 
@@ -19,24 +22,25 @@ def forward(ctx, hidden_states, weight, bias, eps):
     def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
-        grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
-            hidden_states, grad_output, inv_rms, None, weight, bias, eps
-        )
+
+        grad_input = torch.empty_like(hidden_states)
+        grad_weight = torch.empty_like(weight)
+        grad_bias = torch.empty_like(bias)
+
+        ext.rms_norm_backward(grad_input, grad_weight, grad_bias, grad_output, hidden_states, weight, bias, inv_rms, None, eps)
         return grad_input, grad_weight, grad_bias, None
 
 
 class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
     @staticmethod
     def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
-        output, inv_rms = ext.rms_norm(
-            hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps
-        )
+        output = torch.empty_like(hidden_states, dtype=torch.float32)
+        inv_rms_shape = list(hidden_states.shape[:-1], 1)
+        inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=hidden_states.device)
+        ext.rms_norm(output, inv_rms, hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps)
         output = output.half()
         inv_rms = inv_rms.half()
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
-        hidden_states = hidden_states.half()
-        weight = weight.half()
-        bias = bias.half()
         ctx.intermediate_results = normalized_shape
         return output
 
@@ -45,19 +49,23 @@ def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
         normalized_shape = ctx.intermediate_results
-        hidden_states = hidden_states.float()
-        inv_rms = inv_rms.float()
-        weight = weight.float()
-        bias = bias.float()
-        grad_output = grad_output.float()
-        grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
-            hidden_states, grad_output, inv_rms, normalized_shape, weight, bias, eps
-        )
+
+        grad_input = torch.empty_like(hidden_states, dtype=torch.float32)
+        grad_weight = torch.empty_like(weight, dtype=torch.float32)
+        grad_bias = torch.empty_like(bias, dtype=torch.float32)
+        ext.rms_norm_backward(grad_input,
+                              grad_weight,
+                              grad_bias,
+                              grad_output.float(),
+                              hidden_states.float(),
+                              weight.float(),
+                              bias.float(),
+                              inv_rms.float(),
+                              normalized_shape,
+                              eps)
         grad_output = grad_output.half()
         hidden_states = hidden_states.half()
         inv_rms = inv_rms.half()
-        weight = weight.half()
-        bias = bias.half()
         return grad_input, grad_weight, grad_bias, None, None
 
 
diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py
index ea9f66b3..c3bec42a 100644
--- a/tests/test_rms_lightlm.py
+++ b/tests/test_rms_lightlm.py
@@ -17,16 +17,36 @@
 normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda()
 print(input.is_dipu)
 
-output, inv_rms = ext.rms_norm(input, None, weight, bias, 1e-6)
+# output, inv_rms = ext.rms_norm(input, None, weight, bias, 1e-6)
+# output1, inv_rms1 = ext.rms_norm(input, weight.shape, weight, bias, 1e-6)
+
+output = torch.empty_like(input)
+inv_rms_shape = list(input.shape[:-1]) + [1]
+inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device)
+ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6)
+
 
 # 使用 RMS normalization 反向传播
-grad_input, grad_weight, grad_bias = ext.rms_norm_backward(
-    input, grad_output, inv_rms, None, weight, bias, 1e-6
+grad_input = torch.empty_like(grad_output)
+grad_weight = torch.empty_like(weight)
+grad_bias = torch.empty_like(bias)
+ext.rms_norm_backward(
+    grad_input,
+    grad_weight,
+    grad_bias,
+    grad_output,
+    input,
+    weight,
+    bias,
+    inv_rms,
+    weight.shape,
+    1e-6
 )
 
-print("Output:", output)
-print("Grad Input:", grad_input)
-print("Grad Weight:", grad_weight)
-print("Grad Bias:", grad_bias)
+# print("Output:", output)
+# print("Grad Input:", grad_input)
+# print("Grad Weight:", grad_weight)
+# print("Grad Bias:", grad_bias)
 b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
-assert torch.allclose(output, b)
+assert torch.allclose(output, b)
\ No newline at end of file

From b0aa03d1eb384d9de969c1b668be3e37c10376e1 Mon Sep 17 00:00:00 2001
From: Zhangzefeng
Date: Wed, 27 Mar 2024 09:08:52 +0800
Subject: [PATCH 02/39] Update test_rms_lightlm.py

---
 tests/test_rms_lightlm.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py
index c3bec42a..316b3c9d 100644
--- a/tests/test_rms_lightlm.py
+++ b/tests/test_rms_lightlm.py
@@ -17,9 +17,6 @@
 normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda()
 print(input.is_dipu)
 
-# output, inv_rms = ext.rms_norm(input, None, weight, bias, 1e-6)
-# output1, inv_rms1 = ext.rms_norm(input, weight.shape, weight, bias, 1e-6)
-
 output = torch.empty_like(input)
 inv_rms_shape = list(input.shape[:-1]) + [1]
 inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device)
@@ -44,9 +41,5 @@
     1e-6
 )
 
-# print("Output:", output)
-# print("Grad Input:", grad_input)
-# print("Grad Weight:", grad_weight)
-# print("Grad Bias:", grad_bias)
 b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
-assert torch.allclose(output, b)
\ No newline at end of file
+assert torch.allclose(output, b)

From 0517125d9670b55350dd4f52f7bd887338c00041 Mon Sep 17 00:00:00 2001
From: Zhangzefeng
Date: Wed, 27 Mar 2024 09:09:47 +0800
Subject: [PATCH 03/39] Update test_rms_lightlm.py

---
 tests/test_rms_lightlm.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py
index 316b3c9d..35915e2d 100644
--- a/tests/test_rms_lightlm.py
+++ b/tests/test_rms_lightlm.py
@@ -41,5 +41,9 @@
     1e-6
 )
 
+print("Output:", output)
+print("Grad Input:", grad_input)
+print("Grad Weight:", grad_weight)
+print("Grad Bias:", grad_bias)
 b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight
 assert torch.allclose(output, b)

From d675c9f9731c23fed1ae7ac63575b54b99105a62 Mon Sep 17 00:00:00 2001
From: Zhangzefeng
Date: Wed, 27 Mar 2024 10:08:18 +0800
Subject: [PATCH 04/39] Update extensions.cpp

---
 csrc/extensions.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index d3b32ada..79e02ad5 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include
 #include
 #include

From 3f100c019197a4023f12c213cbb6af985afd2211 Mon Sep 17 00:00:00 2001
From: zhangzefeng92
Date: Wed, 27 Mar 2024 10:11:53 +0800
Subject: [PATCH 05/39] modify extensions.cpp

---
 csrc/extensions.cpp | 43 +++++++++++++------------------------------
 1 file changed, 13 insertions(+), 30 deletions(-)

diff 
--git a/csrc/extensions.cpp b/csrc/extensions.cpp index 79e02ad5..93e87430 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -26,46 +26,29 @@ namespace dipu::dipu_ext { - -auto extRmsNorm(at::Tensor& output, - at::Tensor& inv_rms, +auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms, const at::Tensor& input, const OptionalIntArray& normalized_shape, - const at::Tensor& weight, - const at::Tensor& bias, - double eps) { + const at::Tensor& weight, const at::Tensor& bias, double eps) { at::OptionalIntArrayRef normalized_shape_at = *normalized_shape; - callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, bias, eps); + callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, + bias, eps); return std::make_tuple(std::move(output), std::move(inv_rms)); } - -auto extRmsNormBackward(at::Tensor& grad_input, - at::Tensor& grad_weight, - at::Tensor& grad_bias, - const at::Tensor& grad_output, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - const at::Tensor& inv_rms, - const OptionalIntArray& normalized_shape, - double eps) { +auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight, + at::Tensor& grad_bias, const at::Tensor& grad_output, + const at::Tensor& input, const at::Tensor& weight, + const at::Tensor& bias, const at::Tensor& inv_rms, + const OptionalIntArray& normalized_shape, double eps) { at::OptionalIntArrayRef normalized_shape_at = *normalized_shape; - callDiopi(diopiRMSNormBackward, - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape_at, + callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias, + grad_output, input, weight, bias, inv_rms, normalized_shape_at, eps); - return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); + return std::make_tuple(std::move(grad_input), std::move(grad_weight), + std::move(grad_bias)); } - void extApplyRotary(at::Tensor output, const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, const bool conj, const bool interleaved) { From fce081f9e365f2e6b130c39e09592af82267d489 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Wed, 27 Mar 2024 10:19:14 +0800 Subject: [PATCH 06/39] fix python lint --- .../internlm_ops/rms_norm/deeplink.py | 57 +++++++++++++------ tests/test_rms_lightlm.py | 2 +- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index f9342fc0..82142e17 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -12,7 +12,9 @@ class _DeepLinkRMSNormFunction(torch.autograd.Function): def forward(ctx, hidden_states, weight, bias, eps): output = torch.empty_like(hidden_states) inv_rms_shape = list(hidden_states.shape[:-1], 1) - inv_rms = torch.empty(inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device) + inv_rms = torch.empty( + inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device + ) ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps) ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) @@ -22,12 +24,23 @@ def forward(ctx, hidden_states, weight, bias, eps): def backward(ctx, grad_output): hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors eps = eps_tensor.item() - + grad_input = torch.empty_like(hidden_states) grad_weight = torch.empty_like(weight) grad_bias = 
torch.empty_like(bias) - ext.rms_norm_backward(grad_input, grad_weight, grad_bias, grad_output, hidden_states, weight, bias, inv_rms, None, eps) + ext.rms_norm_backward( + grad_input, + grad_weight, + grad_bias, + grad_output, + hidden_states, + weight, + bias, + inv_rms, + None, + eps + ) return grad_input, grad_weight, grad_bias, None @@ -36,8 +49,18 @@ class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function): def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): output = torch.empty_like(hidden_states, dtype=torch.float32) inv_rms_shape = list(hidden_states.shape[:-1], 1) - inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=hidden_states.device) - ext.rms_norm(output, inv_rms, hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps) + inv_rms = torch.empty( + inv_rms_shape, dtype=torch.float32, device=hidden_states.device + ) + ext.rms_norm( + output, + inv_rms, + hidden_states.float(), + normalized_shape, + weight.float(), + bias.float(), + eps + ) output = output.half() inv_rms = inv_rms.half() ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) @@ -49,20 +72,22 @@ def backward(ctx, grad_output): hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors eps = eps_tensor.item() normalized_shape = ctx.intermediate_results - + grad_input = torch.empty_like(hidden_states, dtype=torch.float32) grad_weight = torch.empty_like(weight, dtype=torch.float32) grad_bias = torch.empty_like(bias, dtype=torch.float32) - ext.rms_norm_backward(grad_input, - grad_weight, - grad_bias, - grad_output.float(), - hidden_states.float(), - weight.float(), - bias.float(), - inv_rms.float(), - normalized_shape, - eps) + ext.rms_norm_backward( + grad_input, + grad_weight, + grad_bias, + grad_output.float(), + hidden_states.float(), + weight.float(), + bias.float(), + inv_rms.float(), + normalized_shape, + eps + ) grad_output = grad_output.half() hidden_states = hidden_states.half() inv_rms = inv_rms.half() diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py index 35915e2d..31f081fa 100644 --- a/tests/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -38,7 +38,7 @@ bias, inv_rms, weight.shape, - 1e-6 + 1e-6, ) print("Output:", output) From 7e59912ba48046b8bfdd2a6f4fa4ed825858592a Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Wed, 27 Mar 2024 10:20:39 +0800 Subject: [PATCH 07/39] fix python lint --- deeplink_ext/internlm_ops/rms_norm/deeplink.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index 82142e17..793bc02e 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -28,7 +28,7 @@ def backward(ctx, grad_output): grad_input = torch.empty_like(hidden_states) grad_weight = torch.empty_like(weight) grad_bias = torch.empty_like(bias) - + ext.rms_norm_backward( grad_input, grad_weight, From bd74cb0eda6b9515e40c138e0efb944fd1ad689e Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Wed, 27 Mar 2024 10:21:47 +0800 Subject: [PATCH 08/39] fix python lint --- deeplink_ext/internlm_ops/rms_norm/deeplink.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index 793bc02e..9403340d 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -39,7 +39,7 
@@ def backward(ctx, grad_output): bias, inv_rms, None, - eps + eps, ) return grad_input, grad_weight, grad_bias, None @@ -59,7 +59,7 @@ def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): normalized_shape, weight.float(), bias.float(), - eps + eps, ) output = output.half() inv_rms = inv_rms.half() @@ -86,7 +86,7 @@ def backward(ctx, grad_output): bias.float(), inv_rms.float(), normalized_shape, - eps + eps, ) grad_output = grad_output.half() hidden_states = hidden_states.half() From 5027589476fa22b9fc47af67f4fa50c985a488c8 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Wed, 27 Mar 2024 10:23:02 +0800 Subject: [PATCH 09/39] fix python lint --- tests/test_rms_lightlm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py index 31f081fa..a5ee2baa 100644 --- a/tests/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -23,7 +23,6 @@ ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6) - # 使用 RMS normalization 反向传播 grad_input = torch.empty_like(grad_output) grad_weight = torch.empty_like(weight) From 7fc26c7e55afffe0cabe97bcbadaa0386ead0186 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Wed, 27 Mar 2024 11:07:28 +0800 Subject: [PATCH 10/39] fix rms norm --- deeplink_ext/internlm_ops/rms_norm/deeplink.py | 9 ++++----- tests/test_rms_lightlm.py | 9 ++++++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index 9403340d..d42208bc 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -11,11 +11,11 @@ class _DeepLinkRMSNormFunction(torch.autograd.Function): @staticmethod def forward(ctx, hidden_states, weight, bias, eps): output = torch.empty_like(hidden_states) - inv_rms_shape = list(hidden_states.shape[:-1], 1) + inv_rms_shape = list(hidden_states.shape[:-1]) + [1] inv_rms = torch.empty( inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device ) - ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps) + ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps) ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) return output @@ -28,7 +28,6 @@ def backward(ctx, grad_output): grad_input = torch.empty_like(hidden_states) grad_weight = torch.empty_like(weight) grad_bias = torch.empty_like(bias) - ext.rms_norm_backward( grad_input, grad_weight, @@ -38,7 +37,7 @@ def backward(ctx, grad_output): weight, bias, inv_rms, - None, + weight.shape, eps, ) return grad_input, grad_weight, grad_bias, None @@ -48,7 +47,7 @@ class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function): @staticmethod def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): output = torch.empty_like(hidden_states, dtype=torch.float32) - inv_rms_shape = list(hidden_states.shape[:-1], 1) + inv_rms_shape = list(hidden_states.shape[:-1]) + [1] inv_rms = torch.empty( inv_rms_shape, dtype=torch.float32, device=hidden_states.device ) diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py index a5ee2baa..0e6b911c 100644 --- a/tests/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -22,7 +22,6 @@ inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device) ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6) - # 使用 RMS normalization 反向传播 grad_input = torch.empty_like(grad_output) grad_weight = torch.empty_like(weight) @@ 
-44,5 +43,13 @@ print("Grad Input:", grad_input) print("Grad Weight:", grad_weight) print("Grad Bias:", grad_bias) + +input.requires_grad_(True) +weight.requires_grad_(True) +bias.requires_grad_(True) b = input * torch.rsqrt(input.pow(2).mean(-1, keepdim=True) + 1e-6) * weight +grads = torch.autograd.grad(b, [input, weight, bias], grad_output, allow_unused=True) assert torch.allclose(output, b) +assert torch.allclose(grad_input, grads[0]) +assert torch.allclose(grad_weight, grads[1]) +# assert torch.allclose(grad_bias, grads[2]) From 136e8e18db8fc531d151ba13d948e68aabdb990f Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Thu, 28 Mar 2024 09:21:47 +0800 Subject: [PATCH 11/39] modify rms norm --- csrc/extensions.cpp | 18 --- csrc/pybind_type_cast.h | 24 ---- .../internlm_ops/rms_norm/deeplink.py | 135 +++++++++++------- deeplink_ext/patch_lightllm.py | 13 +- tests/test_rms_internlm.py | 6 +- tests/test_rms_lightlm.py | 26 +--- 6 files changed, 106 insertions(+), 116 deletions(-) diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index 93e87430..370b163c 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -197,21 +197,6 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch); } -// For lightllm, rms_norm reuses the diopi implementation of internlm -auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight, - float eps) { - at::ScalarType acc_type = x.scalar_type(); - if (x.scalar_type() == at::kBFloat16 || x.scalar_type() == at::kHalf) { - acc_type = at::kFloat; - } - auto inv_rms = at::empty_like(x, acc_type); - auto out = at::empty_like(x); - auto bias = at::empty_like(weight); - at::OptionalIntArrayRef normalized_shape = weight.sizes(); - callDiopi(diopiRMSNorm, out, inv_rms, x, normalized_shape, weight, bias, eps); - return out; -} - // For lightllm, rotary_embedding reuses the diopi implementation of internlm void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) { auto seq_len = q.size(0); @@ -229,9 +214,6 @@ void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { if (&diopiRMSNorm != nullptr) { // Check if weak symbol defined m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm"); - m.def("rms_norm_lightllm", &extRmsNormLightllm, - "deeplink ext_rms_norm for lightllm", py::arg("x"), py::arg("weight"), - py::arg("eps")); } if (&diopiRMSNormBackward != nullptr) { m.def("rms_norm_backward", &extRmsNormBackward, diff --git a/csrc/pybind_type_cast.h b/csrc/pybind_type_cast.h index 6d128981..61e8f484 100644 --- a/csrc/pybind_type_cast.h +++ b/csrc/pybind_type_cast.h @@ -21,28 +21,4 @@ using OptionalIntArray = c10::optional; } // namespace dipu::dipu_ext -namespace pybind11::detail { - -namespace py = pybind11; - -template <> -struct type_caster { - public: - PYBIND11_TYPE_CASTER(dipu::dipu_ext::OptionalIntArray, _("OptionalIntArray")); - - bool load(py::handle src, bool /*unused*/) { - if (PyList_Check(src.ptr())) { - value = py::cast(src); - return true; - } - if (src.is_none()) { - value = c10::nullopt; - return true; - } - return false; - } -}; - -} // namespace pybind11::detail - #endif /* end of include guard: PYBIND_TYPE_CAST_H_PXMGELYW */ diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index d42208bc..e6841775 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ 
b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -1,64 +1,110 @@ # Copyright (c) 2024, DeepLink. import torch -import deeplink_ext.cpp_extensions as ext +import deeplink_ext.cpp_extensions as cpp_ext -assert hasattr(ext, "rms_norm") +assert hasattr(cpp_ext, "rms_norm") -# 定义自定义的 autograd 函数 -class _DeepLinkRMSNormFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_states, weight, bias, eps): - output = torch.empty_like(hidden_states) - inv_rms_shape = list(hidden_states.shape[:-1]) + [1] - inv_rms = torch.empty( - inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device - ) - ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, bias, eps) +def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps): + if None == normalized_shape: + cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps) + else: + cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps) - ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) - return output - @staticmethod - def backward(ctx, grad_output): - hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors - eps = eps_tensor.item() +def rms_norm(input, normalized_shape, weight, bias, eps): + output = torch.empty_like(input) + inv_rms_shape = list(input.shape[:-1]) + [1] + inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device) + rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps) + + return [output, inv_rms] - grad_input = torch.empty_like(hidden_states) - grad_weight = torch.empty_like(weight) - grad_bias = torch.empty_like(bias) - ext.rms_norm_backward( + +def rms_norm_backward_out( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + normalized_shape, + eps, +): + if None == normalized_shape: + cpp_ext.rms_norm_backward( grad_input, grad_weight, grad_bias, grad_output, - hidden_states, + input, weight, bias, inv_rms, weight.shape, eps, ) + else: + cpp_ext.rms_norm_backward( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + normalized_shape, + eps, + ) + + +def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps): + grad_input = torch.empty_like(input) + grad_weight = torch.empty_like(weight) + grad_bias = torch.empty_like(bias) + rms_norm_backward_out( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + normalized_shape, + eps, + ) + + return [grad_input, grad_weight, grad_bias] + + +# 定义自定义的 autograd 函数 +class _DeepLinkRMSNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden_states, weight, bias, eps): + output, inv_rms = rms_norm(hidden_states, None, weight, bias, eps) + ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) + return output + + @staticmethod + def backward(ctx, grad_output): + hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors + eps = eps_tensor.item() + grad_input, grad_weight, grad_bias = rms_norm_backward( + hidden_states, grad_output, inv_rms, None, weight, bias, eps + ) return grad_input, grad_weight, grad_bias, None class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function): @staticmethod def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): - output = torch.empty_like(hidden_states, dtype=torch.float32) - inv_rms_shape = list(hidden_states.shape[:-1]) + [1] - inv_rms = torch.empty( - 
inv_rms_shape, dtype=torch.float32, device=hidden_states.device - ) - ext.rms_norm( - output, - inv_rms, - hidden_states.float(), - normalized_shape, - weight.float(), - bias.float(), - eps, + output, inv_rms = rms_norm( + hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps ) output = output.half() inv_rms = inv_rms.half() @@ -72,24 +118,15 @@ def backward(ctx, grad_output): eps = eps_tensor.item() normalized_shape = ctx.intermediate_results - grad_input = torch.empty_like(hidden_states, dtype=torch.float32) - grad_weight = torch.empty_like(weight, dtype=torch.float32) - grad_bias = torch.empty_like(bias, dtype=torch.float32) - ext.rms_norm_backward( - grad_input, - grad_weight, - grad_bias, - grad_output.float(), + grad_input, grad_weight, grad_bias = rms_norm_backward( hidden_states.float(), - weight.float(), - bias.float(), + grad_output.float(), inv_rms.float(), normalized_shape, + weight.float(), + bias.float(), eps, ) - grad_output = grad_output.half() - hidden_states = hidden_states.half() - inv_rms = inv_rms.half() return grad_input, grad_weight, grad_bias, None, None diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py index d371870a..4aeee369 100644 --- a/deeplink_ext/patch_lightllm.py +++ b/deeplink_ext/patch_lightllm.py @@ -51,7 +51,18 @@ def patch_token_softmax_reducev_inference(): ) def patch_rms_norm_lightllm(): - rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm + import torch + + def rms_norm_lightllm(x, weight, eps): + output = torch.empty_like(x) + inv_rms_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype + inv_rms = torch.empty_like(x, dtype=inv_rms_dtype) + bias = torch.empty_like(weight) + ext.rms_norm(output, inv_rms, x, weight.shape, weight, bias, eps) + + return output + + rms_norm_pack.rmsnorm_forward = rms_norm_lightllm def patch_rotary_emb(): rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb diff --git a/tests/test_rms_internlm.py b/tests/test_rms_internlm.py index 72bca9ef..8d1a7807 100644 --- a/tests/test_rms_internlm.py +++ b/tests/test_rms_internlm.py @@ -5,7 +5,7 @@ import deeplink_ext.internlm_ops.rms_norm as ext -def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): +def rms_norm_test(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): x_base = torch.randn(5, 5, requires_grad=True).cuda() x_base.retain_grad() @@ -29,9 +29,9 @@ def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): print( "Test case: normalized_shape == None: grad_inputs closed ? ", - test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNorm), + rms_norm_test(ext.fallback.RMSNorm, ext.DeepLinkRMSNorm), ) print( "Test case: normalized_shape == weight.size(): grad_inputs closed ? ", - test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNormWithNormalizedShape), + rms_norm_test(ext.fallback.RMSNorm, ext.DeepLinkRMSNormWithNormalizedShape), ) diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py index 0e6b911c..ba57369b 100644 --- a/tests/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -1,7 +1,7 @@ # Copyright (c) 2023, DeepLink. 
 import torch
-import deeplink_ext.cpp_extensions as ext
+from deeplink_ext.internlm_ops.rms_norm.deeplink import rms_norm, rms_norm_backward
 
 # 定义输入张量
 input = torch.randn(5, 5, requires_grad=True).cuda()
@@ -17,26 +17,10 @@
 normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda()
 print(input.is_dipu)
 
-output = torch.empty_like(input)
-inv_rms_shape = list(input.shape[:-1]) + [1]
-inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=input.device)
-ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, 1e-6)
-
-# 使用 RMS normalization 反向传播
-grad_input = torch.empty_like(grad_output)
-grad_weight = torch.empty_like(weight)
-grad_bias = torch.empty_like(bias)
-ext.rms_norm_backward(
-    grad_input,
-    grad_weight,
-    grad_bias,
-    grad_output,
-    input,
-    weight,
-    bias,
-    inv_rms,
-    weight.shape,
-    1e-6,
+output, inv_rms = rms_norm(input, None, weight, bias, 1e-6)
+
+grad_input, grad_weight, grad_bias = rms_norm_backward(
+    input, grad_output, inv_rms, None, weight, bias, 1e-6
 )
 
 print("Output:", output)

From 27da2ecc43fb5d7646983c857c0e3d7de6409538 Mon Sep 17 00:00:00 2001
From: zhangzefeng92
Date: Mon, 1 Apr 2024 10:41:20 +0800
Subject: [PATCH 23/39] modify rms norm

---
 deeplink_ext/common/rms_norm/__init__.py |  4 +
 deeplink_ext/common/rms_norm/deeplink.py | 79 ++++++++++++++++++
 .../internlm_ops/rms_norm/deeplink.py     | 80 +------------------
 deeplink_ext/patch_lightllm.py            | 14 +---
 4 files changed, 86 insertions(+), 91 deletions(-)
 create mode 100644 deeplink_ext/common/rms_norm/__init__.py
 create mode 100644 deeplink_ext/common/rms_norm/deeplink.py

diff --git a/deeplink_ext/common/rms_norm/__init__.py b/deeplink_ext/common/rms_norm/__init__.py
new file mode 100644
index 00000000..81b96b65
--- /dev/null
+++ b/deeplink_ext/common/rms_norm/__init__.py
@@ -0,0 +1,4 @@
+from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward
+
+
+all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"]
\ No newline at end of file
diff --git a/deeplink_ext/common/rms_norm/deeplink.py b/deeplink_ext/common/rms_norm/deeplink.py
new file mode 100644
index 00000000..ac83a411
--- /dev/null
+++ b/deeplink_ext/common/rms_norm/deeplink.py
@@ -0,0 +1,79 @@
+import torch
+import deeplink_ext.cpp_extensions as cpp_ext
+
+
+def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps):
+    if None == normalized_shape:
+        cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps)
+    else:
+        cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps)
+
+
+def rms_norm(input, normalized_shape, weight, bias, eps):
+    output = torch.empty_like(input)
+    inv_rms_shape = list(input.shape[:-1]) + [1]
+    inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device)
+    rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps)
+
+    return [output, inv_rms]
+
+
+def rms_norm_backward_out(
+    grad_input,
+    grad_weight,
+
grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + normalized_shape, + eps, +): + if None == normalized_shape: + cpp_ext.rms_norm_backward( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + weight.shape, + eps, + ) + else: + cpp_ext.rms_norm_backward( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + normalized_shape, + eps, + ) + + +def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps): + grad_input = torch.empty_like(input) + grad_weight = torch.empty_like(weight) + grad_bias = torch.empty_like(bias) + rms_norm_backward_out( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + bias, + inv_rms, + normalized_shape, + eps, + ) + + return [grad_input, grad_weight, grad_bias] + diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py index e6841775..8eef9e0f 100644 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -1,85 +1,7 @@ # Copyright (c) 2024, DeepLink. import torch -import deeplink_ext.cpp_extensions as cpp_ext - -assert hasattr(cpp_ext, "rms_norm") - - -def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps): - if None == normalized_shape: - cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps) - else: - cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps) - - -def rms_norm(input, normalized_shape, weight, bias, eps): - output = torch.empty_like(input) - inv_rms_shape = list(input.shape[:-1]) + [1] - inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device) - rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps) - - return [output, inv_rms] - - -def rms_norm_backward_out( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape, - eps, -): - if None == normalized_shape: - cpp_ext.rms_norm_backward( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - weight.shape, - eps, - ) - else: - cpp_ext.rms_norm_backward( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape, - eps, - ) - - -def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps): - grad_input = torch.empty_like(input) - grad_weight = torch.empty_like(weight) - grad_bias = torch.empty_like(bias) - rms_norm_backward_out( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape, - eps, - ) - - return [grad_input, grad_weight, grad_bias] +from deeplink_ext.common.rms_norm.deeplink import rms_norm, rms_norm_backward # 定义自定义的 autograd 函数 diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py index 4aeee369..f97d31d4 100644 --- a/deeplink_ext/patch_lightllm.py +++ b/deeplink_ext/patch_lightllm.py @@ -51,18 +51,8 @@ def patch_token_softmax_reducev_inference(): ) def patch_rms_norm_lightllm(): - import torch - - def rms_norm_lightllm(x, weight, eps): - output = torch.empty_like(x) - inv_rms_dtype = torch.float16 if x.dtype == torch.bfloat16 else x.dtype - inv_rms = torch.empty_like(x, dtype=inv_rms_dtype) - bias = torch.empty_like(weight) - ext.rms_norm(output, inv_rms, x, weight.shape, weight, bias, eps) - - return output - - rms_norm_pack.rmsnorm_forward = rms_norm_lightllm + from 
.common.rms_norm.deeplink import rms_norm + rms_norm_pack.rmsnorm_forward = rms_norm def patch_rotary_emb(): rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb From 3f894a28592b8037c37fa0d3d6c268b8e5f1ea9b Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Mon, 1 Apr 2024 10:43:55 +0800 Subject: [PATCH 24/39] modify rms norm --- deeplink_ext/common/rms_norm/__init__.py | 2 +- deeplink_ext/common/rms_norm/deeplink.py | 1 - deeplink_ext/patch_lightllm.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deeplink_ext/common/rms_norm/__init__.py b/deeplink_ext/common/rms_norm/__init__.py index 81b96b65..f91583f3 100644 --- a/deeplink_ext/common/rms_norm/__init__.py +++ b/deeplink_ext/common/rms_norm/__init__.py @@ -1,4 +1,4 @@ from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward -all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] \ No newline at end of file +all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] diff --git a/deeplink_ext/common/rms_norm/deeplink.py b/deeplink_ext/common/rms_norm/deeplink.py index ac83a411..f196d297 100644 --- a/deeplink_ext/common/rms_norm/deeplink.py +++ b/deeplink_ext/common/rms_norm/deeplink.py @@ -76,4 +76,3 @@ def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bia ) return [grad_input, grad_weight, grad_bias] - diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py index f97d31d4..f8c89d7e 100644 --- a/deeplink_ext/patch_lightllm.py +++ b/deeplink_ext/patch_lightllm.py @@ -52,6 +52,7 @@ def patch_token_softmax_reducev_inference(): def patch_rms_norm_lightllm(): from .common.rms_norm.deeplink import rms_norm + rms_norm_pack.rmsnorm_forward = rms_norm def patch_rotary_emb(): From 52bc92709d9e3b0ccb5b9ed1f918752fcfff3b9a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 04:25:12 +0000 Subject: [PATCH 25/39] lint --- deeplink_ext/internlm_ops/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py index 302d1e7f..c64c3a9a 100644 --- a/deeplink_ext/internlm_ops/__init__.py +++ b/deeplink_ext/internlm_ops/__init__.py @@ -5,8 +5,10 @@ try: from .rms_norm import RMSNorm, RMSNormWithNormalizedShape except: - from .rms_norm_fallback import RMSNorm as RMSNorm, RMSNorm as RMSNormWithNormalizedShape + from .rms_norm_fallback import ( + RMSNorm as RMSNorm, + RMSNorm as RMSNormWithNormalizedShape, + ) -__all__ =["mha","rotary", RMSNorm, RMSNormWithNormalizedShape] - +__all__ = ["mha", "rotary", RMSNorm, RMSNormWithNormalizedShape] From 5ac11cacfbd2788afe284a4ad5f2907e14397013 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 04:43:55 +0000 Subject: [PATCH 26/39] delete the duplicated --- deeplink_ext/common/rms_norm/__init__.py | 4 -- deeplink_ext/common/rms_norm/deeplink.py | 78 ------------------------ 2 files changed, 82 deletions(-) delete mode 100644 deeplink_ext/common/rms_norm/__init__.py delete mode 100644 deeplink_ext/common/rms_norm/deeplink.py diff --git a/deeplink_ext/common/rms_norm/__init__.py b/deeplink_ext/common/rms_norm/__init__.py deleted file mode 100644 index f91583f3..00000000 --- a/deeplink_ext/common/rms_norm/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward - - -all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] diff --git a/deeplink_ext/common/rms_norm/deeplink.py 
b/deeplink_ext/common/rms_norm/deeplink.py deleted file mode 100644 index f196d297..00000000 --- a/deeplink_ext/common/rms_norm/deeplink.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch -import deeplink_ext.cpp_extensions as cpp_ext - - -def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps): - if None == normalized_shape: - cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps) - else: - cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps) - - -def rms_norm(input, normalized_shape, weight, bias, eps): - output = torch.empty_like(input) - inv_rms_shape = list(input.shape[:-1]) + [1] - inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device) - rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps) - - return [output, inv_rms] - - -def rms_norm_backward_out( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape, - eps, -): - if None == normalized_shape: - cpp_ext.rms_norm_backward( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - weight.shape, - eps, - ) - else: - cpp_ext.rms_norm_backward( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape, - eps, - ) - - -def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps): - grad_input = torch.empty_like(input) - grad_weight = torch.empty_like(weight) - grad_bias = torch.empty_like(bias) - rms_norm_backward_out( - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape, - eps, - ) - - return [grad_input, grad_weight, grad_bias] From 4a316732cc0262e24fa9e98f816a5a5f0e3660d5 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 04:45:28 +0000 Subject: [PATCH 27/39] delete --- .../internlm_ops/rms_norm/deeplink.py | 83 ------------------- 1 file changed, 83 deletions(-) delete mode 100644 deeplink_ext/internlm_ops/rms_norm/deeplink.py diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py deleted file mode 100644 index 8eef9e0f..00000000 --- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
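# Illustrative usage sketch (not taken from the patch itself) for the
# deeplink_ext.common.rms_norm helpers introduced above. It assumes a built
# deeplink_ext.cpp_extensions on a DIPU/CUDA device; the tensor sizes are made up.
# Conceptually the op normalizes over the last dimension:
#     inv_rms = 1 / sqrt(mean(x ** 2, dim=-1, keepdim=True) + eps)
#     output  = weight * x * inv_rms
# which is why inv_rms is allocated with shape input.shape[:-1] + [1]; the bias
# argument is forwarded as-is to the DIOPI kernel.
import torch
from deeplink_ext.common.rms_norm import rms_norm, rms_norm_backward

x = torch.randn(4, 16).cuda()
weight = torch.ones(16).cuda()
bias = torch.zeros(16).cuda()

# normalized_shape=None makes rms_norm_out fall back to weight.shape.
output, inv_rms = rms_norm(x, None, weight, bias, 1e-6)

grad_output = torch.ones_like(output)
grad_input, grad_weight, grad_bias = rms_norm_backward(
    x, grad_output, inv_rms, None, weight, bias, 1e-6
)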
- -import torch -from deeplink_ext.common.rms_norm.deeplink import rms_norm, rms_norm_backward - - -# 定义自定义的 autograd 函数 -class _DeepLinkRMSNormFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_states, weight, bias, eps): - output, inv_rms = rms_norm(hidden_states, None, weight, bias, eps) - ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) - return output - - @staticmethod - def backward(ctx, grad_output): - hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors - eps = eps_tensor.item() - grad_input, grad_weight, grad_bias = rms_norm_backward( - hidden_states, grad_output, inv_rms, None, weight, bias, eps - ) - return grad_input, grad_weight, grad_bias, None - - -class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): - output, inv_rms = rms_norm( - hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps - ) - output = output.half() - inv_rms = inv_rms.half() - ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) - ctx.intermediate_results = normalized_shape - return output - - @staticmethod - def backward(ctx, grad_output): - hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors - eps = eps_tensor.item() - normalized_shape = ctx.intermediate_results - - grad_input, grad_weight, grad_bias = rms_norm_backward( - hidden_states.float(), - grad_output.float(), - inv_rms.float(), - normalized_shape, - weight.float(), - bias.float(), - eps, - ) - return grad_input, grad_weight, grad_bias, None, None - - -# 定义一个 nn.Module 包裹这个自定义函数 -class DeepLinkRMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(hidden_size)) - self.bias = torch.zeros(hidden_size).cuda() - self.variance_epsilon = eps - - def forward(self, hidden_states): - return _DeepLinkRMSNormFunction.apply( - hidden_states, self.weight, self.bias, self.variance_epsilon - ) - - -class DeepLinkRMSNormWithNormalizedShape(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(hidden_size)) - self.bias = torch.zeros(hidden_size).cuda() - self.variance_epsilon = eps - - def forward(self, hidden_states): - return _DeepLinkRMSNormFunctionWithNormalizedShape.apply( - hidden_states, - self.weight, - self.bias, - self.variance_epsilon, - self.weight.size(), - ) From 8ab1be00a199e224a7694ca368702556e098616b Mon Sep 17 00:00:00 2001 From: Zhangzefeng Date: Mon, 1 Apr 2024 12:51:18 +0800 Subject: [PATCH 28/39] Update __init__.py --- deeplink_ext/common/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplink_ext/common/__init__.py b/deeplink_ext/common/__init__.py index f91583f3..2d3353d9 100644 --- a/deeplink_ext/common/__init__.py +++ b/deeplink_ext/common/__init__.py @@ -1,4 +1,4 @@ from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward -all = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] +__all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] From c314b13d0af5f3eaac114ad10e63542a29c420c5 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 04:52:27 +0000 Subject: [PATCH 29/39] modify test --- tests/test_rms_internlm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_rms_internlm.py 
b/tests/test_rms_internlm.py index 8d1a7807..3dde74de 100644 --- a/tests/test_rms_internlm.py +++ b/tests/test_rms_internlm.py @@ -2,7 +2,8 @@ import torch import numpy as np -import deeplink_ext.internlm_ops.rms_norm as ext +from deeplink_ext.internlm_ops.rms_norm import RMSNorm, RMSNormWithNormalizedShape +from deeplink_ext.internlm_ops.rms_norm_fallback import RMSNorm as RMSNorm_fb def rms_norm_test(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): @@ -29,9 +30,9 @@ def rms_norm_test(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): print( "Test case: normalized_shape == None: grad_inputs closed ? ", - rms_norm_test(ext.fallback.RMSNorm, ext.DeepLinkRMSNorm), + rms_norm_test(RMSNorm_fb, RMSNorm), ) print( "Test case: normalized_shape == weight.size(): grad_inputs closed ? ", - rms_norm_test(ext.fallback.RMSNorm, ext.DeepLinkRMSNormWithNormalizedShape), + rms_norm_test(RMSNorm_fb, RMSNormWithNormalizedShape), ) From dc674ca3436627ac4884e804ad716488a466e531 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 04:56:50 +0000 Subject: [PATCH 30/39] modify --- deeplink_ext/patch_lightllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py index b3cbfedc..3b95024c 100644 --- a/deeplink_ext/patch_lightllm.py +++ b/deeplink_ext/patch_lightllm.py @@ -51,7 +51,7 @@ def patch_token_softmax_reducev_inference(): ) def patch_rms_norm_lightllm(): - from .common.rms_norm.deeplink import rms_norm + from .common.rms_norm import rms_norm rms_norm_pack.rmsnorm_forward = rms_norm From 7ce2b22b546f9806754afb7b309ba08fc249ade4 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 05:10:01 +0000 Subject: [PATCH 31/39] modify rotary_embeding --- deeplink_ext/internlm_ops/__init__.py | 19 +++- deeplink_ext/internlm_ops/rotary/__init__.py | 13 --- deeplink_ext/internlm_ops/rotary/deeplink.py | 65 ------------- deeplink_ext/internlm_ops/rotary/fallback.py | 98 -------------------- 4 files changed, 17 insertions(+), 178 deletions(-) delete mode 100644 deeplink_ext/internlm_ops/rotary/__init__.py delete mode 100644 deeplink_ext/internlm_ops/rotary/deeplink.py delete mode 100644 deeplink_ext/internlm_ops/rotary/fallback.py diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py index c64c3a9a..eb2f9acd 100644 --- a/deeplink_ext/internlm_ops/__init__.py +++ b/deeplink_ext/internlm_ops/__init__.py @@ -1,14 +1,29 @@ # Copyright (c) 2024, DeepLink. -from . import mha, rotary +from . import mha + + +_not_impl = "[deeplink_ext] %s is not implemented in diopi. Falling back to the slower torch implementation." + try: from .rms_norm import RMSNorm, RMSNormWithNormalizedShape except: + print( + _not_impl.format("RMSNorm or RMSNormWithNormalizedShape"), + ) from .rms_norm_fallback import ( RMSNorm as RMSNorm, RMSNorm as RMSNormWithNormalizedShape, ) -__all__ = ["mha", "rotary", RMSNorm, RMSNormWithNormalizedShape] +try: + from .rotary_embedding import apply_rotary +except: + print( _not_impl.format("apply_rotary")) + from .rotary_embeddinig_fallback import apply_rotary + + + +__all__ = ["mha", "RMSNorm", "RMSNormWithNormalizedShape", "apply_rotary"] \ No newline at end of file diff --git a/deeplink_ext/internlm_ops/rotary/__init__.py b/deeplink_ext/internlm_ops/rotary/__init__.py deleted file mode 100644 index 8ca5f9e2..00000000 --- a/deeplink_ext/internlm_ops/rotary/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
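# Minimal consumer-side sketch for the re-exported RMSNorm modules and the
# try/except fallback selection added to internlm_ops/__init__.py above: the same
# import works whether the DIOPI kernel or the torch fallback gets picked. The
# sizes and device below are illustrative assumptions, not from the patch.
import torch
from deeplink_ext.internlm_ops import RMSNorm, RMSNormWithNormalizedShape

norm = RMSNorm(16).cuda()
x = torch.randn(4, 16, requires_grad=True).cuda()
y = norm(x)
y.sum().backward()   # populates x.grad and norm.weight.grad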
- -try: - from .deeplink import apply_rotary -except: - print( - "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n", - end="", - ) - from .fallback import apply_rotary -from . import fallback - -__all__ = ["apply_rotary", "fallback"] diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary/deeplink.py deleted file mode 100644 index 670a47b9..00000000 --- a/deeplink_ext/internlm_ops/rotary/deeplink.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from typing import Optional, Union -import torch -from einops import rearrange -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "apply_rotary") - -__all__ = ["apply_rotary"] - - -def apply_rotary( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - """ - Arguments: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim). - cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Returns: - y: (batch, seqlen, nheads, headdim) - """ - if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None: - raise NotImplementedError( - "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet" - ) - batch, seqlen, nheads, headdim = x.shape - seqlen_ro, rotary_dim = cos.shape - assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - assert ( - cos.dtype == sin.dtype - ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" - assert ( - x.dtype == cos.dtype - ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - - output = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ext.apply_rotary( - output, - x, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - conjugate, - interleaved, - ) - return output diff --git a/deeplink_ext/internlm_ops/rotary/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback.py deleted file mode 100644 index 7ce7c652..00000000 --- a/deeplink_ext/internlm_ops/rotary/fallback.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from typing import Optional, Union -import torch -from einops import rearrange, repeat - -__all__ = ["apply_rotary"] - - -def _rotate_half(x: torch.Tensor, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... 
(d two)", two=2 - ) - - -def _apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved=False -): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - - data_type = x.dtype - x = x.to(torch.float32) - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ).to(data_type) - - -def apply_rotary( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - """ - Arguments: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim). - cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Returns: - y: (batch, seqlen, nheads, headdim) - """ - if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None: - raise NotImplementedError( - "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet" - ) - batch, seqlen, nheads, headdim = x.shape - seqlen_ro, rotary_dim = cos.shape - assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - assert ( - cos.dtype == sin.dtype - ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" - assert ( - x.dtype == cos.dtype - ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - - if conjugate: - sin = -sin - out = _apply_rotary_emb_torch(x, cos[:seqlen], sin[:seqlen], interleaved) - if inplace: - x.copy_(out) - out = x - return out From 03d59920b7df523a4328df90fb230b561cbd2323 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 05:12:09 +0000 Subject: [PATCH 32/39] modify rotary_embeding --- tests/test_rotary_emb_internlm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_rotary_emb_internlm.py b/tests/test_rotary_emb_internlm.py index e212bd82..be08c22b 100644 --- a/tests/test_rotary_emb_internlm.py +++ b/tests/test_rotary_emb_internlm.py @@ -1,7 +1,8 @@ # Copyright (c) 2023, DeepLink. 
import torch -import deeplink_ext.internlm_ops.rotary as ext +from deeplink_ext.internlm_ops.rotary_embedding import apply_rotary +from deeplink_ext.internlm_ops.rotary_embeddinig_fallback import apply_rotary as apply_rotary_fb def RotaryEmbTestFloat16() -> bool: @@ -13,10 +14,10 @@ def RotaryEmbTestFloat16() -> bool: inplace = True interleaved = False - res1 = ext.fallback.apply_rotary( + res1 = apply_rotary_fb( input, cos, sin, interleaved=interleaved, inplace=inplace ) - res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) + res2 = apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) # there is a little calculated error with ascend when dtype is float16 return torch.allclose(res1, res2, atol=1e-2, rtol=1e-3) @@ -31,10 +32,10 @@ def RotaryEmbTestFloat32() -> bool: inplace = True interleaved = False - res1 = ext.fallback.apply_rotary( + res1 = apply_rotary_fb( input, cos, sin, interleaved=interleaved, inplace=inplace ) - res2 = ext.apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) + res2 = apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace) return torch.allclose(res1, res2) From 92bcf4656767cd2ee51ae80a23d71e1ce7d57403 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 05:13:57 +0000 Subject: [PATCH 33/39] modify rotary_embeding --- deeplink_ext/internlm_ops/rotary_embedding.py | 65 ++++++++++++ .../rotary_embeddinig_fallback.py | 98 +++++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 deeplink_ext/internlm_ops/rotary_embedding.py create mode 100644 deeplink_ext/internlm_ops/rotary_embeddinig_fallback.py diff --git a/deeplink_ext/internlm_ops/rotary_embedding.py b/deeplink_ext/internlm_ops/rotary_embedding.py new file mode 100644 index 00000000..670a47b9 --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary_embedding.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024, DeepLink. + +from typing import Optional, Union +import torch +from einops import rearrange +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "apply_rotary") + +__all__ = ["apply_rotary"] + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
+ cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None: + raise NotImplementedError( + "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet" + ) + batch, seqlen, nheads, headdim = x.shape + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + ext.apply_rotary( + output, + x, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + conjugate, + interleaved, + ) + return output diff --git a/deeplink_ext/internlm_ops/rotary_embeddinig_fallback.py b/deeplink_ext/internlm_ops/rotary_embeddinig_fallback.py new file mode 100644 index 00000000..7ce7c652 --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary_embeddinig_fallback.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024, DeepLink. + +from typing import Optional, Union +import torch +from einops import rearrange, repeat + +__all__ = ["apply_rotary"] + + +def _rotate_half(x: torch.Tensor, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def _apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved=False +): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + + data_type = x.dtype + x = x.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ).to(data_type) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
+ cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + if seqlen_offsets != 0 and cu_seqlens is None and max_seqlen is None: + raise NotImplementedError( + "apply_rotary: seqlen_offsets, cu_seqlens and max_seqlen are not supported yet" + ) + batch, seqlen, nheads, headdim = x.shape + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + if conjugate: + sin = -sin + out = _apply_rotary_emb_torch(x, cos[:seqlen], sin[:seqlen], interleaved) + if inplace: + x.copy_(out) + out = x + return out From b0965fb7b1feba6b521f39cb4770183ec68aa8c1 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 05:14:32 +0000 Subject: [PATCH 34/39] lint --- deeplink_ext/internlm_ops/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py index eb2f9acd..57a7e0f1 100644 --- a/deeplink_ext/internlm_ops/__init__.py +++ b/deeplink_ext/internlm_ops/__init__.py @@ -21,9 +21,8 @@ try: from .rotary_embedding import apply_rotary except: - print( _not_impl.format("apply_rotary")) + print(_not_impl.format("apply_rotary")) from .rotary_embeddinig_fallback import apply_rotary - -__all__ = ["mha", "RMSNorm", "RMSNormWithNormalizedShape", "apply_rotary"] \ No newline at end of file +__all__ = ["mha", "RMSNorm", "RMSNormWithNormalizedShape", "apply_rotary"] From e979201c0973f62bcd649e7c86a9dce938f3a5ae Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 05:18:02 +0000 Subject: [PATCH 35/39] fix --- deeplink_ext/internlm_ops/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py index 57a7e0f1..425fb962 100644 --- a/deeplink_ext/internlm_ops/__init__.py +++ b/deeplink_ext/internlm_ops/__init__.py @@ -3,14 +3,14 @@ from . import mha -_not_impl = "[deeplink_ext] %s is not implemented in diopi. Falling back to the slower torch implementation." +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." 
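# Usage sketch for apply_rotary as specified in the docstring above; the concrete
# sizes, dtype and device are assumptions chosen to satisfy the documented
# constraints (rotary_dim <= headdim, seqlen_ro >= seqlen, matching dtypes).
import torch
from deeplink_ext.internlm_ops import apply_rotary

x = torch.randn(2, 128, 8, 64, dtype=torch.float16).cuda()   # (batch, seqlen, nheads, headdim)
cos = torch.randn(128, 16, dtype=torch.float16).cuda()       # (seqlen_ro, rotary_dim / 2)
sin = torch.randn(128, 16, dtype=torch.float16).cuda()
y = apply_rotary(x, cos, sin, interleaved=False, inplace=False)   # same shape as x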
try: from .rms_norm import RMSNorm, RMSNormWithNormalizedShape except: print( - _not_impl.format("RMSNorm or RMSNormWithNormalizedShape"), + _not_impl.format(op_name="RMSNorm or RMSNormWithNormalizedShape"), ) from .rms_norm_fallback import ( RMSNorm as RMSNorm, @@ -21,7 +21,7 @@ try: from .rotary_embedding import apply_rotary except: - print(_not_impl.format("apply_rotary")) + print(_not_impl.format(op_name="apply_rotary")) from .rotary_embeddinig_fallback import apply_rotary From a3cd9daaeb954a64044238e5c16b27b19a5e9b77 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 05:24:31 +0000 Subject: [PATCH 36/39] fix --- deeplink_ext/internlm_ops/rms_norm.py | 3 +++ deeplink_ext/internlm_ops/rms_norm_fallback.py | 4 ++++ deeplink_ext/patch_internlm.py | 4 ++-- tests/test_rms_internlm.py | 4 ++-- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/deeplink_ext/internlm_ops/rms_norm.py b/deeplink_ext/internlm_ops/rms_norm.py index 89f5cb44..0595f269 100644 --- a/deeplink_ext/internlm_ops/rms_norm.py +++ b/deeplink_ext/internlm_ops/rms_norm.py @@ -4,6 +4,9 @@ from DeepLinkExt.deeplink_ext.common.rms_norm import rms_norm, rms_norm_backward +__all__ = ["RMSNorm", "RMSNormWithNormalizedShape"] + + # 定义自定义的 autograd 函数 class _DeepLinkRMSNormFunction(torch.autograd.Function): @staticmethod diff --git a/deeplink_ext/internlm_ops/rms_norm_fallback.py b/deeplink_ext/internlm_ops/rms_norm_fallback.py index e58a83e7..806754c5 100644 --- a/deeplink_ext/internlm_ops/rms_norm_fallback.py +++ b/deeplink_ext/internlm_ops/rms_norm_fallback.py @@ -2,6 +2,8 @@ import torch +__all__ = ["RMSNorm", "RMSNormWithNormalizedShape"] + # RMSNorm fallback from InternLM class RMSNorm(torch.nn.Module): @@ -22,3 +24,5 @@ def forward(self, hidden_states): hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states + +RMSNormWithNormalizedShape = RMSNorm diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py index f06d7eb3..a897fb52 100644 --- a/deeplink_ext/patch_internlm.py +++ b/deeplink_ext/patch_internlm.py @@ -115,7 +115,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore # if isinstance(module, RMSNorm): # which will fail under this patch. Thus we need also trick `isinstance`. internlm.model.norm.RMSNormTorch.__new__ = lambda _, *args, **kwargs: ( - ext.rms_norm.DeepLinkRMSNormWithNormalizedShape(*args, **kwargs) + ext.rms_norm.RMSNormWithNormalizedShape(*args, **kwargs) ) isinstance_orig = builtins.isinstance builtins.isinstance = lambda obj, class_or_tuple: ( @@ -130,7 +130,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs): # type: ignore ) ) and isinstance_orig( - obj, ext.rms_norm.DeepLinkRMSNormWithNormalizedShape + obj, ext.rms_norm.RMSNormWithNormalizedShape ) ) ) diff --git a/tests/test_rms_internlm.py b/tests/test_rms_internlm.py index 3dde74de..5ff0dd9b 100644 --- a/tests/test_rms_internlm.py +++ b/tests/test_rms_internlm.py @@ -3,7 +3,7 @@ import torch import numpy as np from deeplink_ext.internlm_ops.rms_norm import RMSNorm, RMSNormWithNormalizedShape -from deeplink_ext.internlm_ops.rms_norm_fallback import RMSNorm as RMSNorm_fb +from deeplink_ext.internlm_ops.rms_norm_fallback import RMSNorm as RMSNorm_fb, RMSNormWithNormalizedShape as RMSNormWithNormalizedShape_fb def rms_norm_test(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): @@ -34,5 +34,5 @@ def rms_norm_test(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): ) print( "Test case: normalized_shape == weight.size(): grad_inputs closed ? 
", - rms_norm_test(RMSNorm_fb, RMSNormWithNormalizedShape), + rms_norm_test(RMSNormWithNormalizedShape_fb, RMSNormWithNormalizedShape), ) From ca0b33c3018022a71597dc749e689324b3d23ae3 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Apr 2024 07:25:03 +0000 Subject: [PATCH 37/39] modify mha --- deeplink_ext/internlm_ops/__init__.py | 18 +- deeplink_ext/internlm_ops/mha.py | 473 ++++++++++++++++++ deeplink_ext/internlm_ops/mha/__init__.py | 16 - .../internlm_ops/mha/fallback/__init__.py | 5 - deeplink_ext/internlm_ops/mha/mha.py | 109 ---- deeplink_ext/internlm_ops/mha/mha_func.py | 49 -- .../internlm_ops/mha/mha_kvpacked_func.py | 51 -- .../internlm_ops/mha/mha_qkvpacked_func.py | 50 -- .../internlm_ops/mha/mha_varlen_func.py | 83 --- .../mha/mha_varlen_kvpacked_func.py | 83 --- .../mha/mha_varlen_qkvpacked_func.py | 68 --- .../fallback/fallback.py => mha_fallback.py} | 2 + .../internlm_ops/rms_norm_fallback.py | 1 + deeplink_ext/patch_internlm.py | 8 +- tests/test_mha_internlm.py | 4 +- 15 files changed, 497 insertions(+), 523 deletions(-) create mode 100644 deeplink_ext/internlm_ops/mha.py delete mode 100644 deeplink_ext/internlm_ops/mha/__init__.py delete mode 100644 deeplink_ext/internlm_ops/mha/fallback/__init__.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_varlen_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py delete mode 100644 deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py rename deeplink_ext/internlm_ops/{mha/fallback/fallback.py => mha_fallback.py} (98%) diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py index 425fb962..61f616ad 100644 --- a/deeplink_ext/internlm_ops/__init__.py +++ b/deeplink_ext/internlm_ops/__init__.py @@ -13,8 +13,8 @@ _not_impl.format(op_name="RMSNorm or RMSNormWithNormalizedShape"), ) from .rms_norm_fallback import ( - RMSNorm as RMSNorm, - RMSNorm as RMSNormWithNormalizedShape, + RMSNorm, + RMSNormWithNormalizedShape, ) @@ -25,4 +25,16 @@ from .rotary_embeddinig_fallback import apply_rotary -__all__ = ["mha", "RMSNorm", "RMSNormWithNormalizedShape", "apply_rotary"] +try: + from .mha import SelfAttention, CrossAttention +except Exception as e: + print(_not_impl.format(op_name="mha")) + from .mha_fallback import SelfAttention, CrossAttention + +__all__ = [ + "SelfAttention", + "CrossAttention", + "RMSNorm", + "RMSNormWithNormalizedShape", + "apply_rotary", +] diff --git a/deeplink_ext/internlm_ops/mha.py b/deeplink_ext/internlm_ops/mha.py new file mode 100644 index 00000000..18b8af3d --- /dev/null +++ b/deeplink_ext/internlm_ops/mha.py @@ -0,0 +1,473 @@ +# Copyright (c) 2023, DeepLink. 
+ +import torch +import deeplink_ext.cpp_extensions as ext + +import torch.nn as nn + + +__all__ = [ + "MultiHeadAttention", + "MultiHeadAttentionKVPacked", + "MultiHeadAttentionQKVPacked", + "MultiHeadAttentionVarLen", + "MultiHeadAttentionVarLenKVPacked", + "MultiHeadAttentionVarLenQKVPacked", + "SelfAttention", + "CrossAttention", +] + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") + + +class MultiHeadAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_fwd( + q, + k, + v, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(q, k, v, out, softmax_lse, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + dq, dk, dv = ext.mha_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + None, + None, + None, + ) + return dq, dk, dv, None, None, None, None + + +class MultiHeadAttentionKVPacked(torch.autograd.Function): + @staticmethod + def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_fwd( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(q, kv, out, softmax_lse, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + q, kv, out, softmax_lse, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + ext.mha_bwd( + dout, + q, + kv[:, :, 0], + kv[:, :, 1], + out, + softmax_lse, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dq, + dkv[:, :, 0], + dkv[:, :, 1], + ) + return dq, dkv, None, None, None, None + + +class MultiHeadAttentionQKVPacked(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_fwd( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(qkv, out, softmax_lse, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + qkv, out, softmax_lse, rng_state = ctx.saved_tensors + dqkv = torch.empty_like(qkv) + rng = torch.Generator(device=qkv.device) + rng.set_state(rng_state) + ext.mha_bwd( + dout, + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + out, + softmax_lse, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dqkv[:, :, 0], + dqkv[:, :, 1], + dqkv[:, :, 2], + ) + return dqkv, None, None, None, None + + +assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, 
"mha_varlen_bwd") + + +class MultiHeadAttentionVarLen(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward( + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + ( + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + rng_state, + ) = ctx.saved_tensors + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + dq, dk, dv = ext.mha_varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + None, + None, + None, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +class MultiHeadAttentionVarLenKVPacked(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward( + q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + ( + q, + kv, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + rng_state, + ) = ctx.saved_tensors + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + rng = torch.Generator(device=q.device) + rng.set_state(rng_state) + ext.mha_varlen_bwd( + dout, + q, + kv[:, 0], + kv[:, 1], + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dq, + dkv[:, 0], + dkv[:, 1], + ) + return dq, dkv, None, None, None, None, None, None, None, None + + +class MultiHeadAttentionVarLenQKVPacked(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + causal, + return_softmax and dropout_p > 0, + softmax_scale, + ) + ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng.get_state()) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + 
return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout): + qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + dqkv = torch.empty_like(qkv) + rng = torch.Generator(device=qkv.device) + rng.set_state(rng_state) + ext.mha_varlen_bwd( + dout, + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + out, + softmax_lse, + cu_seqlens, + cu_seqlens, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.causal, + rng, + ctx.softmax_scale, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + ) + return dqkv, None, None, None, None, None, None, None, None + + +class SelfAttention(nn.Module): + """Performs self-attention with support for both padded and unpadded sequences. + + Args: + causal (bool, optional): If True, applies causal self-attention, meaning each + position can only attend to previous positions. Default is False. + softmax_scale (float, optional): Scaling factor applied to the softmax + operation. If not provided, will be D^{-0.5}. Default is None. + dropout_p (float, optional): Dropout probability applied to the attention + scores. Default is 0.0. + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): + """Performs self-attention on the input sequences. + + Args: + qkv (torch.Tensor): Input tensor representing queries, keys, and values + concatenated together. (B, S, 3, H, D) for padded; (total, 3, H, D) + for unpadded. + causal (bool, optional): If provided, overrides the class-level 'causal' + argument for this forward pass. Default is None. + cu_seqlens (torch.Tensor((batch_size + 1,), dtype=torch.int32), optional): + Sequence lengths tensor for unpadded sequences. If provided, performs + attention on unpadded sequences. Default is None. + max_seqlen (int, optional): Maximum sequence length for unpadded sequences. + If provided, defines the maximum length of the sequences. Default is + None. + + Returns: + torch.Tensor: Output tensor after applying self-attention. 
+ """ + if cu_seqlens is None: + # padded + return MultiHeadAttentionQKVPacked.apply( + qkv, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) + else: + # unpadded + return MultiHeadAttentionVarLenQKVPacked.apply( + qkv, + cu_seqlens, + max_seqlen, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) + + +class CrossAttention(nn.Module): + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward( + self, + q, + kv, + causal=None, + cu_seqlens=None, + max_seqlen=None, + cu_seqlens_k=None, + max_seqlen_k=None, + ): + if cu_seqlens is None: + # padded + return MultiHeadAttentionKVPacked.apply( + q, + kv, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) + else: + # unpadded + return MultiHeadAttentionVarLenKVPacked.apply( + q, + kv, + cu_seqlens, + cu_seqlens_k, + max_seqlen, + max_seqlen_k, + self.dropout_p if self.training else 0.0, + self.softmax_scale, + causal if causal is not None else self.causal, + False, + ) diff --git a/deeplink_ext/internlm_ops/mha/__init__.py b/deeplink_ext/internlm_ops/mha/__init__.py deleted file mode 100644 index 212ddfd9..00000000 --- a/deeplink_ext/internlm_ops/mha/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -try: - from .mha import DeepLinkSelfAttention, DeepLinkCrossAttention -except Exception as e: - print( - "[deeplink_ext] mha is not implemented in diopi. Falling back to the slower implementation.\n", - end="", - ) - from .fallback import ( - SelfAttention as DeepLinkSelfAttention, - CrossAttention as DeepLinkCrossAttention, - ) -from . import fallback - -__all__ = ["DeepLinkSelfAttention", "DeepLinkCrossAttention", "fallback"] diff --git a/deeplink_ext/internlm_ops/mha/fallback/__init__.py b/deeplink_ext/internlm_ops/mha/fallback/__init__.py deleted file mode 100644 index 8f12c7d4..00000000 --- a/deeplink_ext/internlm_ops/mha/fallback/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from .fallback import SelfAttention, CrossAttention - -__all__ = ["SelfAttention", "CrossAttention"] diff --git a/deeplink_ext/internlm_ops/mha/mha.py b/deeplink_ext/internlm_ops/mha/mha.py deleted file mode 100644 index 00798027..00000000 --- a/deeplink_ext/internlm_ops/mha/mha.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -import torch.nn as nn -from .mha_qkvpacked_func import DeepLinkMultiHeadAttentionQKVPackedFunc -from .mha_varlen_qkvpacked_func import DeepLinkMultiHeadAttentionVarLenQKVPackedFunc -from .mha_kvpacked_func import DeepLinkMultiHeadAttentionKVPackedFunc -from .mha_varlen_kvpacked_func import DeepLinkMultiHeadAttentionVarLenKVPackedFunc - - -class DeepLinkSelfAttention(nn.Module): - """Performs self-attention with support for both padded and unpadded sequences. - - Args: - causal (bool, optional): If True, applies causal self-attention, meaning each - position can only attend to previous positions. Default is False. - softmax_scale (float, optional): Scaling factor applied to the softmax - operation. If not provided, will be D^{-0.5}. Default is None. - dropout_p (float, optional): Dropout probability applied to the attention - scores. Default is 0.0. 
- """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): - """Performs self-attention on the input sequences. - - Args: - qkv (torch.Tensor): Input tensor representing queries, keys, and values - concatenated together. (B, S, 3, H, D) for padded; (total, 3, H, D) - for unpadded. - causal (bool, optional): If provided, overrides the class-level 'causal' - argument for this forward pass. Default is None. - cu_seqlens (torch.Tensor((batch_size + 1,), dtype=torch.int32), optional): - Sequence lengths tensor for unpadded sequences. If provided, performs - attention on unpadded sequences. Default is None. - max_seqlen (int, optional): Maximum sequence length for unpadded sequences. - If provided, defines the maximum length of the sequences. Default is - None. - - Returns: - torch.Tensor: Output tensor after applying self-attention. - """ - if cu_seqlens is None: - # padded - return DeepLinkMultiHeadAttentionQKVPackedFunc.apply( - qkv, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) - else: - # unpadded - return DeepLinkMultiHeadAttentionVarLenQKVPackedFunc.apply( - qkv, - cu_seqlens, - max_seqlen, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) - - -class DeepLinkCrossAttention(nn.Module): - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward( - self, - q, - kv, - causal=None, - cu_seqlens=None, - max_seqlen=None, - cu_seqlens_k=None, - max_seqlen_k=None, - ): - if cu_seqlens is None: - # padded - return DeepLinkMultiHeadAttentionKVPackedFunc.apply( - q, - kv, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) - else: - # unpadded - return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply( - q, - kv, - cu_seqlens, - cu_seqlens_k, - max_seqlen, - max_seqlen_k, - self.dropout_p if self.training else 0.0, - self.softmax_scale, - causal if causal is not None else self.causal, - False, - ) diff --git a/deeplink_ext/internlm_ops/mha/mha_func.py b/deeplink_ext/internlm_ops/mha/mha_func.py deleted file mode 100644 index 3efecb5d..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_func.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") - - -class DeepLinkMultiHeadAttentionFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_fwd( - q, - k, - v, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(q, k, v, out, softmax_lse, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - dq, dk, dv = ext.mha_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - None, - None, - None, - ) - return dq, dk, dv, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py deleted file mode 100644 index 33e248f1..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") - - -class DeepLinkMultiHeadAttentionKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_fwd( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(q, kv, out, softmax_lse, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - q, kv, out, softmax_lse, rng_state = ctx.saved_tensors - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - ext.mha_bwd( - dout, - q, - kv[:, :, 0], - kv[:, :, 1], - out, - softmax_lse, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dq, - dkv[:, :, 0], - dkv[:, :, 1], - ) - return dq, dkv, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py deleted file mode 100644 index 61527adb..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") - - -class DeepLinkMultiHeadAttentionQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_fwd( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(qkv, out, softmax_lse, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - qkv, out, softmax_lse, rng_state = ctx.saved_tensors - dqkv = torch.empty_like(qkv) - rng = torch.Generator(device=qkv.device) - rng.set_state(rng_state) - ext.mha_bwd( - dout, - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - out, - softmax_lse, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ) - return dqkv, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_varlen_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_func.py deleted file mode 100644 index 194a458d..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_varlen_func.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, "mha_varlen_bwd") - - -class DeepLinkMultiHeadAttentionVarLenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward( - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - ( - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - rng_state, - ) = ctx.saved_tensors - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - dq, dk, dv = ext.mha_varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - None, - None, - None, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py deleted file mode 100644 index 18569def..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, "mha_varlen_bwd") - - -class DeepLinkMultiHeadAttentionVarLenKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( - q, - kv[:, 0], - kv[:, 1], - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward( - q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng.get_state() - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - ( - q, - kv, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - rng_state, - ) = ctx.saved_tensors - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - rng = torch.Generator(device=q.device) - rng.set_state(rng_state) - ext.mha_varlen_bwd( - dout, - q, - kv[:, 0], - kv[:, 1], - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dq, - dkv[:, 0], - dkv[:, 1], - ) - return dq, dkv, None, None, None, None, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py deleted file mode 100644 index 562d0047..00000000 --- a/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023, DeepLink. 
- -import torch -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "mha_varlen_fwd") and hasattr(ext, "mha_varlen_bwd") - - -class DeepLinkMultiHeadAttentionVarLenQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - causal, - return_softmax and dropout_p > 0, - softmax_scale, - ) - ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng.get_state()) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout): - qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors - dqkv = torch.empty_like(qkv) - rng = torch.Generator(device=qkv.device) - rng.set_state(rng_state) - ext.mha_varlen_bwd( - dout, - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - out, - softmax_lse, - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.causal, - rng, - ctx.softmax_scale, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - ) - return dqkv, None, None, None, None, None, None, None, None diff --git a/deeplink_ext/internlm_ops/mha/fallback/fallback.py b/deeplink_ext/internlm_ops/mha_fallback.py similarity index 98% rename from deeplink_ext/internlm_ops/mha/fallback/fallback.py rename to deeplink_ext/internlm_ops/mha_fallback.py index 9c0a4c90..b14de68c 100644 --- a/deeplink_ext/internlm_ops/mha/fallback/fallback.py +++ b/deeplink_ext/internlm_ops/mha_fallback.py @@ -4,6 +4,8 @@ import torch.nn as nn import einops +__all__ = ["SelfAttention", "CrossAttention"] + class SelfAttention(nn.Module): def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): diff --git a/deeplink_ext/internlm_ops/rms_norm_fallback.py b/deeplink_ext/internlm_ops/rms_norm_fallback.py index 806754c5..6cd64fcb 100644 --- a/deeplink_ext/internlm_ops/rms_norm_fallback.py +++ b/deeplink_ext/internlm_ops/rms_norm_fallback.py @@ -25,4 +25,5 @@ def forward(self, hidden_states): return self.weight * hidden_states + RMSNormWithNormalizedShape = RMSNorm diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py index a897fb52..2da73528 100644 --- a/deeplink_ext/patch_internlm.py +++ b/deeplink_ext/patch_internlm.py @@ -66,10 +66,10 @@ def CrossEntropyLossProxy(reduction, **_): import flash_attn.modules.mha # type: ignore - flash_attn.modules.mha.SelfAttention = ext.mha.DeepLinkSelfAttention - flash_attn.modules.mha.FlashSelfAttention = ext.mha.DeepLinkSelfAttention - flash_attn.modules.mha.CrossAttention = ext.mha.DeepLinkCrossAttention - flash_attn.modules.mha.FlashCrossAttention = ext.mha.DeepLinkCrossAttention + flash_attn.modules.mha.SelfAttention = ext.mha.SelfAttention + flash_attn.modules.mha.FlashSelfAttention = ext.mha.SelfAttention + flash_attn.modules.mha.CrossAttention = ext.mha.CrossAttention + flash_attn.modules.mha.FlashCrossAttention = ext.mha.CrossAttention def _patch_ops(): import deeplink_ext.internlm_ops as ext diff --git a/tests/test_mha_internlm.py b/tests/test_mha_internlm.py index b74ecc47..735be33b 100644 --- a/tests/test_mha_internlm.py +++ b/tests/test_mha_internlm.py @@ -29,7 +29,7 @@ def _run_cross_attention( D = 8 qkv = 
torch.randn(B, S, 3, H, D, dtype=torch.float16).cuda()
 output_gold, grad_gold = _run_self_attention(ext.fallback.SelfAttention, qkv)
-output_ext, grad_ext = _run_self_attention(ext.DeepLinkSelfAttention, qkv)
+output_ext, grad_ext = _run_self_attention(ext.SelfAttention, qkv)
 assert torch.allclose(output_gold, output_ext, atol=1e-3)
 print("SelfAttention forward test pass")
 assert torch.allclose(grad_gold, grad_ext, atol=2e-3)
@@ -40,7 +40,7 @@ def _run_cross_attention(
 output_gold, dq_gold, dkv_gold = _run_cross_attention(
     ext.fallback.CrossAttention, q, kv
 )
-output_ext, dq_ext, dkv_ext = _run_cross_attention(ext.DeepLinkCrossAttention, q, kv)
+output_ext, dq_ext, dkv_ext = _run_cross_attention(ext.CrossAttention, q, kv)
 assert torch.allclose(output_gold, output_ext, atol=1e-3)
 print("CrossAttention forward test pass")
 assert torch.allclose(dq_gold, dq_ext, atol=2e-3)

From 6d798aa6a0d7306b4a2ee63d9ed6125ebea25cf5 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 1 Apr 2024 07:26:09 +0000
Subject: [PATCH 38/39] rename rotary_embedding

---
 ...rotary_embeddinig_fallback.py => rotary_embedding_fallback.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename deeplink_ext/internlm_ops/{rotary_embeddinig_fallback.py => rotary_embedding_fallback.py} (100%)

diff --git a/deeplink_ext/internlm_ops/rotary_embeddinig_fallback.py b/deeplink_ext/internlm_ops/rotary_embedding_fallback.py
similarity index 100%
rename from deeplink_ext/internlm_ops/rotary_embeddinig_fallback.py
rename to deeplink_ext/internlm_ops/rotary_embedding_fallback.py

From 3380cfb361853c56c20e157a88315abbc5751c40 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 1 Apr 2024 07:30:26 +0000
Subject: [PATCH 39/39] lint

---
 deeplink_ext/patch_internlm.py    |  4 +---
 tests/test_rms_internlm.py        |  5 ++++-
 tests/test_rotary_emb_internlm.py | 12 +++++-------
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py
index 2da73528..17bc373f 100644
--- a/deeplink_ext/patch_internlm.py
+++ b/deeplink_ext/patch_internlm.py
@@ -129,9 +129,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):  # type: ignore
                     else (class_or_tuple,)
                 )
             )
-            and isinstance_orig(
-                obj, ext.rms_norm.RMSNormWithNormalizedShape
-            )
+            and isinstance_orig(obj, ext.rms_norm.RMSNormWithNormalizedShape)
         )
     )

diff --git a/tests/test_rms_internlm.py b/tests/test_rms_internlm.py
index 5ff0dd9b..dbb02667 100644
--- a/tests/test_rms_internlm.py
+++ b/tests/test_rms_internlm.py
@@ -3,7 +3,10 @@
 import torch
 import numpy as np
 from deeplink_ext.internlm_ops.rms_norm import RMSNorm, RMSNormWithNormalizedShape
-from deeplink_ext.internlm_ops.rms_norm_fallback import RMSNorm as RMSNorm_fb, RMSNormWithNormalizedShape as RMSNormWithNormalizedShape_fb
+from deeplink_ext.internlm_ops.rms_norm_fallback import (
+    RMSNorm as RMSNorm_fb,
+    RMSNormWithNormalizedShape as RMSNormWithNormalizedShape_fb,
+)


 def rms_norm_test(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3):
diff --git a/tests/test_rotary_emb_internlm.py b/tests/test_rotary_emb_internlm.py
index be08c22b..70e93d05 100644
--- a/tests/test_rotary_emb_internlm.py
+++ b/tests/test_rotary_emb_internlm.py
@@ -2,7 +2,9 @@

 import torch
 from deeplink_ext.internlm_ops.rotary_embedding import apply_rotary
-from deeplink_ext.internlm_ops.rotary_embeddinig_fallback import apply_rotary as apply_rotary_fb
+from deeplink_ext.internlm_ops.rotary_embedding_fallback import (
+    apply_rotary as apply_rotary_fb,
+)


 def RotaryEmbTestFloat16() -> bool:
def RotaryEmbTestFloat16() -> bool:
     inplace = True
     interleaved = False

-    res1 = apply_rotary_fb(
-        input, cos, sin, interleaved=interleaved, inplace=inplace
-    )
+    res1 = apply_rotary_fb(input, cos, sin, interleaved=interleaved, inplace=inplace)
     res2 = apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace)

     # there is a little calculated error with ascend when dtype is float16
@@ -32,9 +32,7 @@ def RotaryEmbTestFloat32() -> bool:
     inplace = True
     interleaved = False

-    res1 = apply_rotary_fb(
-        input, cos, sin, interleaved=interleaved, inplace=inplace
-    )
+    res1 = apply_rotary_fb(input, cos, sin, interleaved=interleaved, inplace=inplace)
     res2 = apply_rotary(input1, cos, sin, interleaved=interleaved, inplace=inplace)

     return torch.allclose(res1, res2)
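
For reference, a minimal usage sketch of the renamed attention classes added above. This is not part of the patch series: the import path deeplink_ext.internlm_ops.mha is inferred from patch_internlm.py, the tensor shapes follow the SelfAttention.forward docstring, and a CUDA/DIPU device is assumed to be available.

import torch
from deeplink_ext.internlm_ops import mha  # assumed module path

B, S, H, D = 2, 16, 4, 32
attn = mha.SelfAttention(causal=True, attention_dropout=0.1).cuda()

# Padded path: qkv packs queries, keys and values as (B, S, 3, H, D).
qkv = torch.randn(B, S, 3, H, D, dtype=torch.float16, device="cuda", requires_grad=True)
out = attn(qkv)

# Unpadded path: qkv is (total, 3, H, D); cu_seqlens ((batch_size + 1,), int32)
# and max_seqlen describe where each sequence starts, as documented in forward.
qkv_unpad = torch.randn(B * S, 3, H, D, dtype=torch.float16, device="cuda")
cu_seqlens = torch.arange(0, (B + 1) * S, S, dtype=torch.int32, device="cuda")
out_unpad = attn(qkv_unpad, cu_seqlens=cu_seqlens, max_seqlen=S)

out.sum().backward()  # exercises the mha backward path shown above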
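
Similarly, a hedged sanity check for the monkey-patching updated in patch_internlm.py. It assumes that importing deeplink_ext.patch_internlm is what applies the patches (the trigger is not shown in this diff) and that InternLM's flash_attn package is importable in the current environment.

import deeplink_ext.patch_internlm  # noqa: F401  (assumed to apply the patches on import)
import flash_attn.modules.mha as flash_mha
import deeplink_ext.internlm_ops as ext

# After patching, the flash_attn attention classes should resolve to the renamed
# DeepLink implementations rather than the old DeepLink* names.
assert flash_mha.SelfAttention is ext.mha.SelfAttention
assert flash_mha.FlashSelfAttention is ext.mha.SelfAttention
assert flash_mha.CrossAttention is ext.mha.CrossAttention
assert flash_mha.FlashCrossAttention is ext.mha.CrossAttention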