diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index ca23c3e814f..176390cf40b 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -101,7 +101,7 @@ class Algorithm final : public luci::CircleNodeVisitor // loco::TensorShape visit(const luci::CircleMinimum *node) final; // loco::TensorShape visit(const luci::CircleMirrorPad *node) final; loco::TensorShape visit(const luci::CircleMul *node) final; - // loco::TensorShape visit(const luci::CircleNeg *node) final; + loco::TensorShape visit(const luci::CircleNeg *node) final; // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final; // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final; // loco::TensorShape visit(const luci::CircleNotEqual *node) final; diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index 31e4c146855..2514533696f 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -2142,8 +2142,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitorboxes()).as(); diff --git a/compiler/luci/service/src/Nodes/CircleNeg.cpp b/compiler/luci/service/src/Nodes/CircleNeg.cpp index 20190fd891f..25d2cb16dd8 100644 --- a/compiler/luci/service/src/Nodes/CircleNeg.cpp +++ b/compiler/luci/service/src/Nodes/CircleNeg.cpp @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "luci/Service/CircleShapeInference.h" #include "CircleCloneNode.h" +#include "CircleShapeInferenceHelper.h" namespace luci { @@ -24,4 +26,16 @@ luci::CircleNode *CloneNodeLet::visit(const luci::CircleNeg *) return _graph->nodes()->create(); } +namespace sinf +{ + +loco::TensorShape Algorithm::visit(const luci::CircleNeg *node) +{ + const auto input_x = loco::must_cast(node->x()); + const auto input_shape = circle_shape(input_x); + return input_shape; +} + +} // namespace sinf + } // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleNeg.test.cpp b/compiler/luci/service/src/Nodes/CircleNeg.test.cpp index 8c288032453..42a0dcd2ec9 100644 --- a/compiler/luci/service/src/Nodes/CircleNeg.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleNeg.test.cpp @@ -15,6 +15,7 @@ */ #include "luci/Service/CircleNodeClone.h" +#include "luci/Service/CircleShapeInference.h" #include @@ -31,3 +32,41 @@ TEST(CloneNodeTest, clone_Neg) auto cloned_neg = dynamic_cast(cloned); ASSERT_NE(nullptr, cloned_neg); } + +TEST(ShapeRuleTest, Neg_dynamic_shape) +{ + luci::CircleInput input; + luci::CircleNeg neg; + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + input.shape({1, 1, 3, 4}); + input.shape_status(luci::ShapeStatus::VALID); + input.dim(1).unset(); + + neg.x(&input); + + ASSERT_TRUE(shape_inf_rule.infer(&neg, shape)); + ASSERT_EQ(shape.rank(), 4); + ASSERT_TRUE(shape.dim(0).known()); + ASSERT_FALSE(shape.dim(1).known()); + ASSERT_TRUE(shape.dim(2).known()); + ASSERT_TRUE(shape.dim(3).known()); + + ASSERT_EQ(1, shape.dim(0).value()); + ASSERT_EQ(0, shape.dim(1).value()); + ASSERT_EQ(3, shape.dim(2).value()); + ASSERT_EQ(4, shape.dim(3).value()); +} + +TEST(ShapeRuleTest, Neg_nullptr_input_NEG) +{ + luci::CircleNeg neg; + + loco::TensorShape shape; + luci::sinf::Rule shape_inf_rule; + + neg.x(nullptr); + ASSERT_ANY_THROW(shape_inf_rule.infer(&neg, shape)); +}