Skip to content

Commit

Permalink
[luci/lang] Add CircleGRU node (Samsung#12598)
Browse files Browse the repository at this point in the history
This pr adds CircleGRU node in luci/lang

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem authored Mar 13, 2024
1 parent afd538e commit 6765b4e
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/lang/include/luci/IR/CircleNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
// Circle only
#include "Nodes/CircleBCQFullyConnected.h"
#include "Nodes/CircleBCQGather.h"
#include "Nodes/CircleGRU.h"
#include "Nodes/CircleInstanceNorm.h"
// Virtual nodes
#include "Nodes/CircleConst.h"
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/lang/include/luci/IR/CircleNodes.lst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ CIRCLE_NODE(GATHER_ND, CircleGatherNd)
CIRCLE_NODE(GELU, CircleGelu)
CIRCLE_NODE(GREATER, CircleGreater)
CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqual)
CIRCLE_NODE(GRU, CircleGRU)
CIRCLE_NODE(HARD_SWISH, CircleHardSwish)
CIRCLE_NODE(IF, CircleIf)
CIRCLE_NODE(L2_NORMALIZATION, CircleL2Normalize)
Expand Down
70 changes: 70 additions & 0 deletions compiler/luci/lang/include/luci/IR/Nodes/CircleGRU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_IR_CIRCLE_GRU_H__
#define __LUCI_IR_CIRCLE_GRU_H__

#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"

#include "luci/IR/CircleNodeMixins.h"

namespace luci
{

/**
* @brief GRU in Circle
*/
class CircleGRU final : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::GRU>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }

loco::Node *hidden_hidden(void) const { return at(1)->node(); }
void hidden_hidden(loco::Node *node) { at(1)->node(node); }

loco::Node *hidden_hidden_bias(void) const { return at(2)->node(); }
void hidden_hidden_bias(loco::Node *node) { at(2)->node(node); }

loco::Node *hidden_input(void) const { return at(3)->node(); }
void hidden_input(loco::Node *node) { at(3)->node(node); }

loco::Node *hidden_input_bias(void) const { return at(4)->node(); }
void hidden_input_bias(loco::Node *node) { at(4)->node(node); }

loco::Node *state(void) const { return at(5)->node(); }
void state(loco::Node *node) { at(5)->node(node); }

public:
FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }

bool returnSequences() const { return _return_sequences; }
void returnSequences(bool return_sequences) { _return_sequences = return_sequences; }

bool timeMajor() const { return _time_major; }
void timeMajor(bool time_major) { _time_major = time_major; }

private:
FusedActFunc _fused_act_fun = FusedActFunc::NONE;
bool _return_sequences = false;
bool _time_major = false;
};

} // namespace luci

#endif // __LUCI_IR_CIRCLE_GRU_H__
86 changes: 86 additions & 0 deletions compiler/luci/lang/src/Nodes/CircleGRU.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/IR/Nodes/CircleGRU.h"

#include "luci/IR/CircleDialect.h"
#include "luci/IR/CircleNodeVisitor.h"

#include <gtest/gtest.h>

TEST(CircleGRUTest, constructor)
{
luci::CircleGRU gru_node;

ASSERT_EQ(luci::CircleDialect::get(), gru_node.dialect());
ASSERT_EQ(luci::CircleOpcode::GRU, gru_node.opcode());

ASSERT_EQ(nullptr, gru_node.input());
ASSERT_EQ(nullptr, gru_node.hidden_hidden());
ASSERT_EQ(nullptr, gru_node.hidden_hidden_bias());
ASSERT_EQ(nullptr, gru_node.hidden_input());
ASSERT_EQ(nullptr, gru_node.hidden_input_bias());
ASSERT_EQ(nullptr, gru_node.state());
}

TEST(CircleGRUTest, input_NEG)
{
luci::CircleGRU gru_node;
luci::CircleGRU node;

gru_node.input(&node);
ASSERT_NE(nullptr, gru_node.input());

gru_node.input(nullptr);
ASSERT_EQ(nullptr, gru_node.input());
}

TEST(CircleGRUTest, arity_NEG)
{
luci::CircleGRU gru_node;

ASSERT_NO_THROW(gru_node.arg(0));
ASSERT_NO_THROW(gru_node.arg(1));
ASSERT_NO_THROW(gru_node.arg(2));
ASSERT_NO_THROW(gru_node.arg(3));
ASSERT_NO_THROW(gru_node.arg(4));
ASSERT_NO_THROW(gru_node.arg(5));
ASSERT_THROW(gru_node.arg(6), std::out_of_range);
}

TEST(CircleGRUTest, visit_mutable_NEG)
{
struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
{
};

luci::CircleGRU gru_node;

TestVisitor tv;
ASSERT_THROW(gru_node.accept(&tv), std::exception);
}

TEST(CircleGRUTest, visit_NEG)
{
struct TestVisitor final : public luci::CircleNodeVisitor<void>
{
};

luci::CircleGRU gru_node;

TestVisitor tv;
ASSERT_THROW(gru_node.accept(&tv), std::exception);
}

0 comments on commit 6765b4e

Please sign in to comment.