Skip to content

Commit

Permalink
[luci/service] Migrate Neg shape inference rule to sinf::Algorithm (S…
Browse files Browse the repository at this point in the history
…amsung#13922)

This commit migrate Neg shape inference rule to sinf::Algorithm.

ONE-DCO-1.0-Signed-off-by: HanJin Choi [email protected]
  • Loading branch information
Hanjin-Choi authored Sep 5, 2024
1 parent b097991 commit 67df4c5
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CircleMinimum *node) final;
// loco::TensorShape visit(const luci::CircleMirrorPad *node) final;
loco::TensorShape visit(const luci::CircleMul *node) final;
// loco::TensorShape visit(const luci::CircleNeg *node) final;
loco::TensorShape visit(const luci::CircleNeg *node) final;
// loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final;
// loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final;
// loco::TensorShape visit(const luci::CircleNotEqual *node) final;
Expand Down
2 changes: 0 additions & 2 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2142,8 +2142,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS

loco::NodeShape visit(const luci::CircleMirrorPad *node) final { return infer_mirror_pad(node); }

loco::NodeShape visit(const luci::CircleNeg *node) final { return use_x(node); }

loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
{
const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
Expand Down
14 changes: 14 additions & 0 deletions compiler/luci/service/src/Nodes/CircleNeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "luci/Service/CircleShapeInference.h"

#include "CircleCloneNode.h"
#include "CircleShapeInferenceHelper.h"

namespace luci
{
Expand All @@ -24,4 +26,16 @@ luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleNeg *)
return _graph->nodes()->create<luci::CircleNeg>();
}

namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleNeg *node)
{
const auto input_x = loco::must_cast<CircleNode *>(node->x());
const auto input_shape = circle_shape(input_x);
return input_shape;
}

} // namespace sinf

} // namespace luci
39 changes: 39 additions & 0 deletions compiler/luci/service/src/Nodes/CircleNeg.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "luci/Service/CircleNodeClone.h"
#include "luci/Service/CircleShapeInference.h"

#include <gtest/gtest.h>

Expand All @@ -31,3 +32,41 @@ TEST(CloneNodeTest, clone_Neg)
auto cloned_neg = dynamic_cast<luci::CircleNeg *>(cloned);
ASSERT_NE(nullptr, cloned_neg);
}

TEST(ShapeRuleTest, Neg_dynamic_shape)
{
luci::CircleInput input;
luci::CircleNeg neg;

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

input.shape({1, 1, 3, 4});
input.shape_status(luci::ShapeStatus::VALID);
input.dim(1).unset();

neg.x(&input);

ASSERT_TRUE(shape_inf_rule.infer(&neg, shape));
ASSERT_EQ(shape.rank(), 4);
ASSERT_TRUE(shape.dim(0).known());
ASSERT_FALSE(shape.dim(1).known());
ASSERT_TRUE(shape.dim(2).known());
ASSERT_TRUE(shape.dim(3).known());

ASSERT_EQ(1, shape.dim(0).value());
ASSERT_EQ(0, shape.dim(1).value());
ASSERT_EQ(3, shape.dim(2).value());
ASSERT_EQ(4, shape.dim(3).value());
}

TEST(ShapeRuleTest, Neg_nullptr_input_NEG)
{
luci::CircleNeg neg;

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

neg.x(nullptr);
ASSERT_ANY_THROW(shape_inf_rule.infer(&neg, shape));
}

0 comments on commit 67df4c5

Please sign in to comment.