From 10fbedbe1621285695c0ee8776352dda482dca90 Mon Sep 17 00:00:00 2001 From: seockho-kim Date: Thu, 19 Sep 2024 06:12:51 +0900 Subject: [PATCH] [luci/service] Support RmsNorm operation (#14025) This commit supports RmsNorm operation in luci service. ONE-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com --- compiler/luci/service/src/CircleCloneNode.h | 1 + .../service/src/CircleShapeInferenceRule.cpp | 7 ++++ .../service/src/CircleTypeInferenceRule.cpp | 5 +++ .../luci/service/src/Nodes/CircleRmsNorm.cpp | 32 +++++++++++++++++ .../service/src/Nodes/CircleRmsNorm.test.cpp | 35 +++++++++++++++++++ 5 files changed, 80 insertions(+) create mode 100644 compiler/luci/service/src/Nodes/CircleRmsNorm.cpp create mode 100644 compiler/luci/service/src/Nodes/CircleRmsNorm.test.cpp diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h index e2f61e1eb0e..64c9e4f486f 100644 --- a/compiler/luci/service/src/CircleCloneNode.h +++ b/compiler/luci/service/src/CircleCloneNode.h @@ -259,6 +259,7 @@ class CloneNode final : public luci::CircleNodeVisitor luci::CircleNode *visit(const luci::CircleBCQGather *) final; luci::CircleNode *visit(const luci::CircleInstanceNorm *) final; luci::CircleNode *visit(const luci::CircleGRU *) final; + luci::CircleNode *visit(const luci::CircleRmsNorm *) final; // NOTE CircleInput and CircleOutput are not handled here as these need // link with graph I/O diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 42c45353361..a094b681d0c 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -2198,6 +2198,13 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitorinput()).as(); + + return loco::NodeShape{input_shape}; + } + // Virtual loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); } diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp index 78dde1004b5..6b656567071 100644 --- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp @@ -579,6 +579,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitorinput()); } + loco::DataType visit(const luci::CircleRmsNorm *node) final + { + return luci::dtype_get(node->input()); + } + // Virtual loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); } diff --git a/compiler/luci/service/src/Nodes/CircleRmsNorm.cpp b/compiler/luci/service/src/Nodes/CircleRmsNorm.cpp new file mode 100644 index 00000000000..0fdf2bdf3d8 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRmsNorm.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CircleCloneNode.h" + +namespace luci +{ + +luci::CircleNode *CloneNode::visit(const luci::CircleRmsNorm *node) +{ + auto *cloned = _graph->nodes()->create(); + if (cloned != nullptr) + { + cloned->epsilon(node->epsilon()); + } + return cloned; +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRmsNorm.test.cpp b/compiler/luci/service/src/Nodes/CircleRmsNorm.test.cpp new file mode 100644 index 00000000000..9bd0bc891da --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRmsNorm.test.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleNodeClone.h" + +#include + +TEST(CloneNodeTest, clone_RmsNorm) +{ + auto g = loco::make_graph(); + auto node_fc = g->nodes()->create(); + node_fc->epsilon(3); + + auto gc = loco::make_graph(); + auto cloned = luci::clone_node(node_fc, gc.get()); + ASSERT_NE(nullptr, cloned); + ASSERT_EQ(gc.get(), cloned->graph()); + + auto cloned_fc = dynamic_cast(cloned); + ASSERT_NE(nullptr, cloned_fc); + ASSERT_EQ(node_fc->epsilon(), cloned_fc->epsilon()); +}