
add the rms norm
wusar committed Aug 15, 2023
1 parent faa7b3b commit b9217a4
Showing 9 changed files with 199 additions and 3 deletions.
5 changes: 4 additions & 1 deletion ark/include/ark.h
@@ -337,9 +337,12 @@ class Model
const std::string &name = "reduce_max");
// Applies layer normalization to the `input` tensor and returns the
// normalized tensor as `output`.

Tensor *layernorm(Tensor *input, Tensor *output = nullptr,
const std::string &name = "layernorm");
// Applies RMS normalization (Root Mean Square Layer Normalization) to the
// `input` tensor and returns the normalized tensor as `output`.
Tensor *rmsnorm(Tensor *input, Tensor *output = nullptr,
const std::string &name = "rmsnorm");
// Applies softmax activation to the `input` tensor, with the softmax operator
// being performed on the last dimension of the input tensor.
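For reference, RMS normalization as defined in the paper cited in the kernel header below (arXiv:1910.07467) rescales each element by the root mean square of its row, with no mean subtraction; the kernel added in this commit uses an epsilon of 1e-5 inside the square root:

\mathrm{rmsnorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}}, \qquad \epsilon = 10^{-5}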
89 changes: 89 additions & 0 deletions ark/include/kernels/layernorm.h
@@ -123,6 +123,95 @@ DEVICE void layernorm(ark::half *out, const ark::half *in, int uop_idx,
smem_per_warp);
}

// Perform RMS normalization on the input and write the result to the output.
// Root Mean Square Layer Normalization: https://arxiv.org/pdf/1910.07467.pdf
template <typename InDims, typename InShape, typename OutDims,
typename OutShape, typename UnitOutDims, int NumThreads,
int SmemBytes, typename DataType, int NelemPerThread>
struct RMSNorm
{
using UnitOp =
UnitOp<OutDims, OutShape, UnitOutDims, NumThreads, SmemBytes>;

static_assert(NelemPerThread > 0, "NelemPerThread must be positive");
static DEVICE void run(DataType *out, const DataType *in, int uop_idx,
int smem_per_warp)
{
using InOutChk = LayerNormShapeChecker<InShape, OutShape>;
using ReduceTypeMean = ReduceTypeMean<DataType, NelemPerThread>;

constexpr int NonReduceDimLength = UnitOutDims::NCH;
// The reduction dimension of the final stage.
// Assume this division is always exact.
static_assert((NumThreads * NelemPerThread) % NonReduceDimLength == 0);
// If we reshape the input into a 2D matrix (NCH x W), NumThreads
// threads compute NCH rows, and each row's sum is computed by
// ThreadsPerRow threads. If ThreadsPerRow is larger than warp size, we
// need to use shared memory to reduce the result of each warp.
constexpr int ThreadsPerRow =
(NumThreads * NelemPerThread) / NonReduceDimLength;
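// e.g. (hypothetical values) with NumThreads = 128, NelemPerThread = 1, and
// UnitOutDims::NCH = 32, each row is reduced by ThreadsPerRow = 4 threads.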

int tid = UnitOp::thread_id();
int tid_w = (tid * NelemPerThread) % ThreadsPerRow;
int tid_h = ((tid * NelemPerThread) / ThreadsPerRow) % UnitOutDims::H;
int tid_c = ((tid * NelemPerThread) / ThreadsPerRow / UnitOutDims::H) %
UnitOutDims::C;
int tid_n = (tid * NelemPerThread) / ThreadsPerRow / UnitOutDims::CH;

int un = UnitOp::uop_idx_n(uop_idx);
int uc = UnitOp::uop_idx_c(uop_idx);
int uh = UnitOp::uop_idx_h(uop_idx);

int idx_in_base = (tid_h + uh * UnitOutDims::H) * InDims::W +
(tid_c + uc * UnitOutDims::C) * InDims::HW +
(tid_n + un * UnitOutDims::N) * InDims::CHW;

DataType variance;
ReduceTypeMean::singleIdentity(&variance);
// accumulate the sum of squares; its mean over W serves as the variance term
UnitOp::sync_threads();
for (int idx_in_w = tid_w; idx_in_w < InShape::W;
idx_in_w += ThreadsPerRow) {
int idx_in = idx_in_base + idx_in_w;
variance += (in[idx_in]) * (in[idx_in]);
}
UnitOp::sync_threads();
variance = warpsReduce<ReduceTypeMean, UnitOp, ThreadsPerRow>(
variance, tid, smem_per_warp);
ReduceTypeMean::singlePostReduce(&variance, &variance, UnitOutDims::W);
UnitOp::sync_threads();
// the output is input * rsqrt(mean(input^2) + epsilon); RMS norm does not subtract the mean
for (int idx_in_w = tid_w; idx_in_w < InShape::W;
idx_in_w += ThreadsPerRow) {
int idx_in = idx_in_base + idx_in_w;
out[idx_in] = (in[idx_in]) * rsqrtf(variance + 1e-5f);
}
}
};

template <typename InDims, typename InShape, typename OutDims,
typename OutShape, typename UnitOutDims, int NumThreads,
int SmemBytes>
DEVICE void rmsnorm(float *out, const float *in, int uop_idx, int smem_per_warp)
{
constexpr int NelemPerThread = 1;
RMSNorm<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
SmemBytes, float, NelemPerThread>::run(out, in, uop_idx,
smem_per_warp);
}

template <typename InDims, typename InShape, typename OutDims,
typename OutShape, typename UnitOutDims, int NumThreads,
int SmemBytes>
DEVICE void rmsnorm(ark::half *out, const ark::half *in, int uop_idx,
int smem_per_warp)
{
constexpr int NelemPerThread = 1;
RMSNorm<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
SmemBytes, ark::half, NelemPerThread>::run(out, in, uop_idx,
smem_per_warp);
}

} // namespace ark

#endif // ARK_KERNELS_LAYERNORM_H_
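A minimal NumPy sketch of the per-row computation the kernel above performs, assuming reduction over the last (W) dimension and the same 1e-5 epsilon; the function and variable names here are illustrative and not part of the ARK API:

import numpy as np

def rmsnorm_reference(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Mean of squares over the last (W) dimension, kept for broadcasting.
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    # Scale each element by the reciprocal root mean square of its row.
    return x / np.sqrt(mean_sq + eps)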
3 changes: 3 additions & 0 deletions ark/ops/ops_common.cc
@@ -78,6 +78,7 @@ ostream &operator<<(ostream &os, const OpType &s)
case OP_RECV: os << "OP_RECV"; break;
case OP_RECV_MM: os << "OP_RECV_MM"; break;
case OP_LAYERNORM: os << "OP_LAYERNORM"; break;
case OP_RMSNORM: os << "OP_RMSNORM"; break;
case OP_SOFTMAX: os << "OP_SOFTMAX"; break;
case OP_RELU: os << "OP_RELU"; break;
case OP_SIGMOID: os << "OP_SIGMOID"; break;
@@ -456,6 +457,8 @@ std::string Op::function_name(const OpConfig &cfg) const
return static_cast<const RecvMMOp *>(this)->function_name(cfg);
case OP_LAYERNORM:
return static_cast<const LayernormOp *>(this)->function_name(cfg);
case OP_RMSNORM:
return static_cast<const RMSnormOp *>(this)->function_name(cfg);
case OP_SOFTMAX:
return static_cast<const SoftmaxOp *>(this)->function_name(cfg);
case OP_RELU:
9 changes: 9 additions & 0 deletions ark/ops/ops_common.h
@@ -106,6 +106,7 @@ typedef enum
OP_REDUCE_W_MEAN,
OP_REDUCE_W_MAX,
OP_LAYERNORM,
OP_RMSNORM,
OP_SOFTMAX,
OP_SCALE,
OP_RELU,
@@ -343,6 +344,14 @@ class LayernormOp : public Op
std::string function_name(const OpConfig &cfg) const;
};

class RMSnormOp : public Op
{
public:
RMSnormOp(OpPrecType prec_type, Tensor *input, Tensor *output,
const std::string &name);
std::string function_name(const OpConfig &cfg) const;
};

class MatmulOp : public Op
{
public:
70 changes: 70 additions & 0 deletions ark/ops/ops_rmsnorm.cc
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "logging.h"
#include "model.h"
#include <cassert>

namespace ark {

extern const OpConfigMap LayernormConfigMap;

RMSnormOp::RMSnormOp(OpPrecType prec_type, Tensor *input, Tensor *output,
const std::string &name)
: Op{OP_RMSNORM, prec_type, {input}, {output}, {},
name, &LayernormConfigMap, -1}
{
}

std::string RMSnormOp::function_name(const OpConfig &cfg) const
{
Tensor *input = this->inputs[0];
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
} else {
CHECK(tile_out.x == 1);
}

Dims unit_out_dims{1, 1, tile_out.x, tile_out.y};
return Op::function_name("ark::rmsnorm",
{{
input->ldims.dims4(), // InDims
input->shape.dims4(), // InShape
output->ldims.dims4(), // OutDims
output->shape.dims4(), // OutShape
unit_out_dims, // UnitOutDims
cfg.num_warps * 32, // NumThreads
cfg.smem_bytes, // SmemBytes
}});
}

Tensor *Model::rmsnorm(Tensor *input, Tensor *output, const std::string &name)
{
assert(input != nullptr);
LOG(DEBUG, "rmsnorm ", input->shape, " ", input->ldims, " ");
OpPrecType pt;
if (input->type == FP16) {
pt = OP_PREC_FP16;
} else if (input->type == FP32) {
pt = OP_PREC_FP32;
} else {
LOG(ERROR, "unsupported input data type: ", input->type);
}
if (output != nullptr && input->type != output->type) {
LOG(ERROR, "invalid output data type: ", output->type);
}
if (output == nullptr) {
output = this->tensor(input->shape, input->type);
} else if (output == input) {
output = this->identity(output);
}
RMSnormOp op{pt, input, output, name};
return this->impl->add_op(op)[0];
}

} // namespace ark
4 changes: 2 additions & 2 deletions examples/llama/llama_test.py
@@ -375,8 +375,8 @@ def test_rotary_embedding():
if __name__ == "__main__":
torch.distributed.init_process_group("nccl")
initialize_model_parallel(1)
# test_rmsnorm()
test_attention()
test_rmsnorm()
# test_attention()
# test_feedforward()
# test_transformerblock()
# test_transformer()
2 changes: 2 additions & 0 deletions python/ark/__init__.py
@@ -41,6 +41,7 @@
reduce_mean,
reduce_max,
layernorm,
rmsnorm,
softmax,
transpose,
matmul,
@@ -96,6 +97,7 @@
"reduce_mean",
"reduce_max",
"layernorm",
"rmsnorm",
"softmax",
"transpose",
"matmul",
15 changes: 15 additions & 0 deletions python/ark/model.py
@@ -204,6 +204,21 @@ def layernorm(
_tensor = Model.get_global_model().layernorm(input._tensor, output, name)
return Tensor(_tensor)

def rmsnorm(
input: Tensor,
output: Tensor = None,
name: str = "rmsnorm",
) -> Tensor:
"""
Applies RMS normalization (Root Mean Square Layer Normalization) to the
`input` tensor and returns the normalized tensor as `output`.
Usage:
tensor_rmsnorm = ark.rmsnorm(tensor)
"""
if output is not None:
output = output._tensor
_tensor = Model.get_global_model().rmsnorm(input._tensor, output, name)
return Tensor(_tensor)

def softmax(
input: Tensor,
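A short usage sketch of the Python binding above, following only the call pattern in its docstring and signature; `x` is assumed to be an existing ark.Tensor created elsewhere in the model:

# Hypothetical usage; `x` is assumed to be an existing ark.Tensor.
y = ark.rmsnorm(x)                    # allocate a new output tensor
y = ark.rmsnorm(x, output=y)          # reuse a pre-allocated output tensor
z = ark.rmsnorm(x, name="ffn_norm")   # give the op an explicit name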
5 changes: 5 additions & 0 deletions python/bindings.cpp
@@ -203,6 +203,11 @@ PYBIND11_MODULE(_ark_core, m)
"the normalized tensor as `output`.",
py::return_value_policy::reference_internal, py::arg("input"),
py::arg("output") = nullptr, py::arg("name") = "layernorm")
.def("rmsnorm", &ark::Model::rmsnorm,
"Applies RMS (Root Mean Square Layer Normalization) normalization to the `input` tensor and returns "
"the normalized tensor as `output`.",
py::return_value_policy::reference_internal, py::arg("input"),
py::arg("output") = nullptr, py::arg("name") = "rmsnorm")
.def("softmax", &ark::Model::softmax,
"Applies softmax activation to the `input` tensor, with the "
"softmax operator being performed on the last dimension of the "
