From 3f100c019197a4023f12c213cbb6af985afd2211 Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Wed, 27 Mar 2024 10:11:53 +0800 Subject: [PATCH] modify extensions.cpp --- csrc/extensions.cpp | 43 +++++++++++++------------------------------ 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index 79e02ad5..93e87430 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -26,46 +26,29 @@ namespace dipu::dipu_ext { - -auto extRmsNorm(at::Tensor& output, - at::Tensor& inv_rms, +auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms, const at::Tensor& input, const OptionalIntArray& normalized_shape, - const at::Tensor& weight, - const at::Tensor& bias, - double eps) { + const at::Tensor& weight, const at::Tensor& bias, double eps) { at::OptionalIntArrayRef normalized_shape_at = *normalized_shape; - callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, bias, eps); + callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, + bias, eps); return std::make_tuple(std::move(output), std::move(inv_rms)); } - -auto extRmsNormBackward(at::Tensor& grad_input, - at::Tensor& grad_weight, - at::Tensor& grad_bias, - const at::Tensor& grad_output, - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - const at::Tensor& inv_rms, - const OptionalIntArray& normalized_shape, - double eps) { +auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight, + at::Tensor& grad_bias, const at::Tensor& grad_output, + const at::Tensor& input, const at::Tensor& weight, + const at::Tensor& bias, const at::Tensor& inv_rms, + const OptionalIntArray& normalized_shape, double eps) { at::OptionalIntArrayRef normalized_shape_at = *normalized_shape; - callDiopi(diopiRMSNormBackward, - grad_input, - grad_weight, - grad_bias, - grad_output, - input, - weight, - bias, - inv_rms, - normalized_shape_at, + callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias, + grad_output, input, weight, bias, inv_rms, normalized_shape_at, eps); - return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); + return std::make_tuple(std::move(grad_input), std::move(grad_weight), + std::move(grad_bias)); } - void extApplyRotary(at::Tensor output, const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, const bool conj, const bool interleaved) {