diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h index f3f4871b469..6fcd5d975d0 100644 --- a/compiler/luci/import/include/luci/Import/Nodes.h +++ b/compiler/luci/import/include/luci/Import/Nodes.h @@ -107,6 +107,7 @@ #include "Nodes/CircleResizeNearestNeighbor.h" #include "Nodes/CircleReverseSequence.h" #include "Nodes/CircleReverseV2.h" +#include "Nodes/CircleRmsNorm.h" #include "Nodes/CircleRound.h" #include "Nodes/CircleRsqrt.h" #include "Nodes/CircleScatterNd.h" diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleRmsNorm.h b/compiler/luci/import/include/luci/Import/Nodes/CircleRmsNorm.h new file mode 100644 index 00000000000..a2ebcdf657b --- /dev/null +++ b/compiler/luci/import/include/luci/Import/Nodes/CircleRmsNorm.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_IMPORT_OP_CIRCLE_RMS_NORM_H__ +#define __LUCI_IMPORT_OP_CIRCLE_RMS_NORM_H__ + +#include "luci/Import/GraphBuilder.h" + +namespace luci +{ + +class CircleRmsNormGraphBuilder : public GraphBuilder +{ +public: + bool validate(const ValidateArgs &args) const final; + +private: + CircleNode *build_node(const circle::OperatorT &op, const std::vector &inputs, + loco::Graph *graph) const final; +}; + +} // namespace luci + +#endif // __LUCI_IMPORT_OP_CIRCLE_RMS_NORM_H__ diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp index 29edf8348f3..1e2e8837029 100644 --- a/compiler/luci/import/src/GraphBuilderRegistry.cpp +++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp @@ -116,6 +116,7 @@ GraphBuilderRegistry::GraphBuilderRegistry() CIRCLE_NODE(RESIZE_NEAREST_NEIGHBOR, CircleResizeNearestNeighborGraphBuilder); // 97 CIRCLE_NODE(REVERSE_SEQUENCE, CircleReverseSequenceGraphBuilder); // 112 CIRCLE_NODE(REVERSE_V2, CircleReverseV2GraphBuilder); // 105 + CIRCLE_NODE(RMS_NORM, CircleRmsNormGraphBuilder); // 255 CIRCLE_NODE(ROUND, CircleRoundGraphBuilder); // 116 CIRCLE_NODE(RSQRT, CircleRsqrtGraphBuilder); // 76 CIRCLE_NODE(SCATTER_ND, CircleScatterNdGraphBuilder); // 122 diff --git a/compiler/luci/import/src/Nodes/CircleRmsNorm.cpp b/compiler/luci/import/src/Nodes/CircleRmsNorm.cpp new file mode 100644 index 00000000000..28fef764a65 --- /dev/null +++ b/compiler/luci/import/src/Nodes/CircleRmsNorm.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Import/Nodes/CircleRmsNorm.h" + +#include + +#include + +namespace luci +{ + +bool CircleRmsNormGraphBuilder::validate(const ValidateArgs &args) const +{ + // TODO check dtypes + return GraphBuilder::validate(args, 3); +} + +CircleNode *CircleRmsNormGraphBuilder::build_node(const circle::OperatorT &op, + const std::vector &inputs, + loco::Graph *graph) const +{ + auto *node = graph->nodes()->create(); + node->input(inputs.at(0)); + node->gamma(inputs.at(1)); + node->beta(inputs.at(2)); + + const auto *options = op.builtin_options.AsRmsNormOptions(); + node->epsilon(options->epsilon); + + return node; +} + +} // namespace luci