Skip to content

Commit

Permalink
support ascend rms_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
jingguo-st authored Jan 18, 2024
1 parent a39abfd commit b1cced5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ext_op/example_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ auto extRmsNorm(const at::Tensor& input,
at::OptionalIntArrayRef normalized_shape_at =
optionalIntArrayToIntArrayRefOrDefault(normalized_shape, weight.sizes());
auto input_shape = input.sizes();
std::vector<int64_t> input_size(input_shape.size(), 1);
std::copy(input_shape.begin(), input_shape.end() - 1, input_size.begin());
std::vector<int64_t> input_size(input_shape.begin(), input_shape.end());
input_size.back() = 1;
auto inv_rms = at::empty(input_size, input.options());
auto output = at::empty_like(input);
callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape_at, weight,
Expand Down

0 comments on commit b1cced5

Please sign in to comment.