Skip to content

Commit

Permalink
[luci/service] migrate Range shape inference rule to sinf::Algorithm (#…
Browse files Browse the repository at this point in the history
…13987)

This commit migrate Range shape inference rule to sinf::Algorithm.

ONE-DCO-1.0-Signed-off-by: bokyeong Lee <[email protected]>
  • Loading branch information
kyeong8139 authored Sep 12, 2024
1 parent c2b372c commit 3a3e50b
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CirclePow *node) final;
// loco::TensorShape visit(const luci::CirclePRelu *node) final;
loco::TensorShape visit(const luci::CircleQuantize *node) final;
// loco::TensorShape visit(const luci::CircleRange *node) final;
loco::TensorShape visit(const luci::CircleRange *node) final;
// loco::TensorShape visit(const luci::CircleRank *node) final;
// loco::TensorShape visit(const luci::CircleReduceAny *node) final;
// loco::TensorShape visit(const luci::CircleReduceMax *node) final;
Expand Down
45 changes: 0 additions & 45 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,49 +911,6 @@ loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
return loco::NodeShape{output_shape};
}

loco::NodeShape infer_range(const luci::CircleRange *node)
{
loco::TensorShape output_shape;
output_shape.rank(1);

auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());

if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
{
return use_own(node);
}

double start = 0, limit = 0, delta = 0;

#define GET_RANGE_PARAM(DT) \
start = start_node->scalar<DT>(); \
limit = limit_node->scalar<DT>(); \
delta = delta_node->scalar<DT>();

switch (start_node->dtype())
{
case loco::DataType::FLOAT32:
GET_RANGE_PARAM(loco::DataType::FLOAT32)
break;
case loco::DataType::S32:
GET_RANGE_PARAM(loco::DataType::S32)
break;
default:
INTERNAL_EXN("Range data type not supported");
}

#undef GET_RANGE_PARAM

if (delta == 0)
INTERNAL_EXN("Delta can not be zero");

output_shape.dim(0) = ceil((limit - start) / delta);

return loco::NodeShape{output_shape};
}

loco::NodeShape infer_reshape(const luci::CircleReshape *node)
{
LOGGER(l);
Expand Down Expand Up @@ -2104,8 +2061,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS

loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }

loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }

loco::NodeShape visit(const luci::CircleRank *) final
{
loco::TensorShape shape_output;
Expand Down
67 changes: 67 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
* limitations under the License.
*/

#include "luci/Service/CircleShapeInference.h"

#include "CircleCloneNode.h"
#include "CircleShapeInferenceHelper.h"

#include <cmath>

namespace luci
{
Expand All @@ -24,4 +29,66 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRange *)
return _graph->nodes()->create<luci::CircleRange>();
}

namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleRange *node)
{
loco::TensorShape output_shape;
output_shape.rank(1);

auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());

if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
{
// We use shape from the node itself
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
{
// TODO remove this copy from `use_own(node);`
// Shape inference rules in this file did not consider unknown dimension.
// If some node has unknown dimension, 0 is inserted and wrong shape
// inference was done as a result.
// To fix this, new shape inference algorithm is being implemented.
// Until new inference algorithm is fully implemented, unknown dimension
// would be represented as 1 along with TFLite expression.
shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
}
return shape;
}

double start = 0, limit = 0, delta = 0;

#define GET_RANGE_PARAM(DT) \
start = start_node->scalar<DT>(); \
limit = limit_node->scalar<DT>(); \
delta = delta_node->scalar<DT>();

switch (start_node->dtype())
{
case loco::DataType::FLOAT32:
GET_RANGE_PARAM(loco::DataType::FLOAT32)
break;
case loco::DataType::S32:
GET_RANGE_PARAM(loco::DataType::S32)
break;
default:
INTERNAL_EXN("Range data type not supported");
}

#undef GET_RANGE_PARAM

if (delta == 0)
INTERNAL_EXN("Delta can not be zero");

output_shape.dim(0) = ceil((limit - start) / delta);

return output_shape;
}

} // namespace sinf

} // namespace luci
64 changes: 64 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRange.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "luci/Service/CircleNodeClone.h"
#include "luci/Service/CircleShapeInference.h"

#include <gtest/gtest.h>

Expand All @@ -31,3 +32,66 @@ TEST(CloneNodeTest, clone_Range)
auto cloned_range = dynamic_cast<luci::CircleRange *>(cloned);
ASSERT_NE(nullptr, cloned_range);
}

TEST(ShapeRuleTest, range_const_param)
{
luci::CircleConst start, limit, delta;
luci::CircleRange range;

start.dtype(loco::DataType::S32);
start.size<loco::DataType::S32>(1);
start.at<loco::DataType::S32>(0) = 0;
start.shape_status(luci::ShapeStatus::VALID);

limit.dtype(loco::DataType::S32);
limit.size<loco::DataType::S32>(1);
limit.at<loco::DataType::S32>(0) = 10;
limit.shape_status(luci::ShapeStatus::VALID);

delta.dtype(loco::DataType::S32);
delta.size<loco::DataType::S32>(1);
delta.at<loco::DataType::S32>(0) = 2;
delta.shape_status(luci::ShapeStatus::VALID);

range.start(&start);
range.limit(&limit);
range.delta(&delta);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&range, shape));
ASSERT_EQ(1, shape.rank());
ASSERT_TRUE(shape.dim(0).known());
ASSERT_EQ(5, shape.dim(0).value());
}

TEST(ShapeRuleTest, range_zero_delta_NEG)
{
luci::CircleConst start, limit, delta;
luci::CircleRange range;

start.dtype(loco::DataType::S32);
start.size<loco::DataType::S32>(1);
start.at<loco::DataType::S32>(0) = 0;
start.shape_status(luci::ShapeStatus::VALID);

limit.dtype(loco::DataType::S32);
limit.size<loco::DataType::S32>(1);
limit.at<loco::DataType::S32>(0) = 10;
limit.shape_status(luci::ShapeStatus::VALID);

delta.dtype(loco::DataType::S32);
delta.size<loco::DataType::S32>(1);
delta.at<loco::DataType::S32>(0) = 0;
delta.shape_status(luci::ShapeStatus::VALID);

range.start(&start);
range.limit(&limit);
range.delta(&delta);

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_ANY_THROW(shape_inf_rule.infer(&range, shape));
}

0 comments on commit 3a3e50b

Please sign in to comment.