Commit

modify extensions.cpp
zhangzefeng92 committed Mar 27, 2024
1 parent d675c9f commit 3f100c0
Showing 1 changed file with 13 additions and 30 deletions.
43 changes: 13 additions & 30 deletions csrc/extensions.cpp
@@ -26,46 +26,29 @@
 
 namespace dipu::dipu_ext {
 
-
-auto extRmsNorm(at::Tensor& output,
-                at::Tensor& inv_rms,
+auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
                 const at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
-                const at::Tensor& weight,
-                const at::Tensor& bias,
-                double eps) {
+                const at::Tensor& weight, const at::Tensor& bias, double eps) {
   at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
-  callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight, bias, eps);
+  callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
+            bias, eps);
   return std::make_tuple(std::move(output), std::move(inv_rms));
 }
 
-
-auto extRmsNormBackward(at::Tensor& grad_input,
-                        at::Tensor& grad_weight,
-                        at::Tensor& grad_bias,
-                        const at::Tensor& grad_output,
-                        const at::Tensor& input,
-                        const at::Tensor& weight,
-                        const at::Tensor& bias,
-                        const at::Tensor& inv_rms,
-                        const OptionalIntArray& normalized_shape,
-                        double eps) {
+auto extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
+                        at::Tensor& grad_bias, const at::Tensor& grad_output,
+                        const at::Tensor& input, const at::Tensor& weight,
+                        const at::Tensor& bias, const at::Tensor& inv_rms,
+                        const OptionalIntArray& normalized_shape, double eps) {
   at::OptionalIntArrayRef normalized_shape_at = *normalized_shape;
-  callDiopi(diopiRMSNormBackward,
-            grad_input,
-            grad_weight,
-            grad_bias,
-            grad_output,
-            input,
-            weight,
-            bias,
-            inv_rms,
-            normalized_shape_at,
+  callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias,
+            grad_output, input, weight, bias, inv_rms, normalized_shape_at,
             eps);
-  return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias));
+  return std::make_tuple(std::move(grad_input), std::move(grad_weight),
+                         std::move(grad_bias));
 }
 
-
 void extApplyRotary(at::Tensor output, const at::Tensor& input,
                     const at::Tensor& cos, const at::Tensor& sin,
                     const bool conj, const bool interleaved) {
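
For readers unfamiliar with the operator being bound here, the sketch below shows the math that diopiRMSNorm is expected to compute, written with plain ATen ops. It is an illustrative reference only, assuming normalization over the last dimension; the helper name rmsNormForwardReference is hypothetical, and in the file above the real kernel is dispatched through callDiopi, not this code.

#include <ATen/ATen.h>
#include <tuple>

// Hypothetical reference; not part of csrc/extensions.cpp.
// inv_rms = 1 / sqrt(mean(x^2) + eps),  y = x * inv_rms * weight + bias.
std::tuple<at::Tensor, at::Tensor> rmsNormForwardReference(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
    double eps) {
  // Mean of squares over the normalized (last) dimension, kept for broadcasting.
  at::Tensor inv_rms = at::rsqrt(input.pow(2).mean(-1, /*keepdim=*/true) + eps);
  at::Tensor output = input * inv_rms * weight + bias;
  return std::make_tuple(output, inv_rms);
}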

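Along the same lines, a hedged sketch of the gradients that diopiRMSNormBackward is expected to produce, derived by differentiating the forward formula above (same last-dimension assumption, helper name again hypothetical):

#include <ATen/ATen.h>
#include <tuple>

// Hypothetical reference for the backward pass; not part of csrc/extensions.cpp.
std::tuple<at::Tensor, at::Tensor, at::Tensor> rmsNormBackwardReference(
    const at::Tensor& grad_output, const at::Tensor& input,
    const at::Tensor& weight, const at::Tensor& inv_rms) {
  const auto n = input.size(-1);         // size of the normalized dimension
  at::Tensor gw = grad_output * weight;  // dL/dy scaled by the affine weight
  // Chain rule through inv_rms = (mean(x^2) + eps)^(-1/2):
  // dL/dx = gw * inv_rms - x * inv_rms^3 * sum(gw * x) / n
  at::Tensor sum_gwx = (gw * input).sum(-1, /*keepdim=*/true);
  at::Tensor grad_input =
      gw * inv_rms - input * inv_rms.pow(3) * sum_gwx / static_cast<double>(n);
  // Per-channel parameters reduce over every dimension except the last one.
  at::Tensor grad_weight = (grad_output * input * inv_rms).reshape({-1, n}).sum(0);
  at::Tensor grad_bias = grad_output.reshape({-1, n}).sum(0);
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}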