Skip to content

Commit

Permalink
[luci/service] Support RmsNorm operation (Samsung#14025)
Browse files Browse the repository at this point in the history
This commit supports RmsNorm operation in luci service.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim authored Sep 18, 2024
1 parent 804fa33 commit 10fbedb
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/service/src/CircleCloneNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *>
luci::CircleNode *visit(const luci::CircleBCQGather *) final;
luci::CircleNode *visit(const luci::CircleInstanceNorm *) final;
luci::CircleNode *visit(const luci::CircleGRU *) final;
luci::CircleNode *visit(const luci::CircleRmsNorm *) final;

// NOTE CircleInput and CircleOutput are not handled here as these need
// link with graph I/O
Expand Down
7 changes: 7 additions & 0 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2198,6 +2198,13 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS

loco::NodeShape visit(const luci::CircleGRU *node) final { return infer_circle_gru(node); }

loco::NodeShape visit(const luci::CircleRmsNorm *node) final
{
auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();

return loco::NodeShape{input_shape};
}

// Virtual
loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }

Expand Down
5 changes: 5 additions & 0 deletions compiler/luci/service/src/CircleTypeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
return luci::dtype_get(node->input());
}

loco::DataType visit(const luci::CircleRmsNorm *node) final
{
return luci::dtype_get(node->input());
}

// Virtual
loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }

Expand Down
32 changes: 32 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRmsNorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "CircleCloneNode.h"

namespace luci
{

luci::CircleNode *CloneNode::visit(const luci::CircleRmsNorm *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleRmsNorm>();
if (cloned != nullptr)
{
cloned->epsilon(node->epsilon());
}
return cloned;
}

} // namespace luci
35 changes: 35 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRmsNorm.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Service/CircleNodeClone.h"

#include <gtest/gtest.h>

TEST(CloneNodeTest, clone_RmsNorm)
{
auto g = loco::make_graph();
auto node_fc = g->nodes()->create<luci::CircleRmsNorm>();
node_fc->epsilon(3);

auto gc = loco::make_graph();
auto cloned = luci::clone_node(node_fc, gc.get());
ASSERT_NE(nullptr, cloned);
ASSERT_EQ(gc.get(), cloned->graph());

auto cloned_fc = dynamic_cast<luci::CircleRmsNorm *>(cloned);
ASSERT_NE(nullptr, cloned_fc);
ASSERT_EQ(node_fc->epsilon(), cloned_fc->epsilon());
}

0 comments on commit 10fbedb

Please sign in to comment.