Skip to content

Commit

Permalink
[luci/service] migrate Range shape inference rule to sinf::Algorithm
Browse files Browse the repository at this point in the history
This commit migrate Range shape inference rule to sinf::Algorithm.

ONE-DCO-1.0-Signed-off-by: Bokyeong Lee <[email protected]>
  • Loading branch information
kyeong8139 committed Sep 11, 2024
1 parent 8a66a21 commit ce3d380
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
// loco::TensorShape visit(const luci::CirclePow *node) final;
// loco::TensorShape visit(const luci::CirclePRelu *node) final;
loco::TensorShape visit(const luci::CircleQuantize *node) final;
// loco::TensorShape visit(const luci::CircleRange *node) final;
loco::TensorShape visit(const luci::CircleRange *node) final;
// loco::TensorShape visit(const luci::CircleRank *node) final;
// loco::TensorShape visit(const luci::CircleReduceAny *node) final;
// loco::TensorShape visit(const luci::CircleReduceMax *node) final;
Expand Down
45 changes: 0 additions & 45 deletions compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,49 +911,6 @@ loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
return loco::NodeShape{output_shape};
}

loco::NodeShape infer_range(const luci::CircleRange *node)
{
loco::TensorShape output_shape;
output_shape.rank(1);

auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());

if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
{
return use_own(node);
}

double start = 0, limit = 0, delta = 0;

#define GET_RANGE_PARAM(DT) \
start = start_node->scalar<DT>(); \
limit = limit_node->scalar<DT>(); \
delta = delta_node->scalar<DT>();

switch (start_node->dtype())
{
case loco::DataType::FLOAT32:
GET_RANGE_PARAM(loco::DataType::FLOAT32)
break;
case loco::DataType::S32:
GET_RANGE_PARAM(loco::DataType::S32)
break;
default:
INTERNAL_EXN("Range data type not supported");
}

#undef GET_RANGE_PARAM

if (delta == 0)
INTERNAL_EXN("Delta can not be zero");

output_shape.dim(0) = ceil((limit - start) / delta);

return loco::NodeShape{output_shape};
}

loco::NodeShape infer_reshape(const luci::CircleReshape *node)
{
LOGGER(l);
Expand Down Expand Up @@ -2104,8 +2061,6 @@ class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeS

loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }

loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }

loco::NodeShape visit(const luci::CircleRank *) final
{
loco::TensorShape shape_output;
Expand Down
48 changes: 48 additions & 0 deletions compiler/luci/service/src/Nodes/CircleRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,52 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRange *)
return _graph->nodes()->create<luci::CircleRange>();
}

namespace sinf
{

loco::TensorShape Algorithm::visit(const luci::CircleRange *node)
{
loco::TensorShape output_shape;
output_shape.rank(1);

auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());

if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
{
return output_shape;
}

double start = 0, limit = 0, delta = 0;

#define GET_RANGE_PARAM(DT) \
start = start_node->scalar<DT>(); \
limit = limit_node->scalar<DT>(); \
delta = delta_node->scalar<DT>();

switch (start_node->dtype())
{
case loco::DataType::FLOAT32:
GET_RANGE_PARAM(loco::DataType::FLOAT32)
break;
case loco::DataType::S32:
GET_RANGE_PARAM(loco::DataType::S32)
break;
default:
INTERNAL_EXN("Range data type not supported");
}

#undef GET_RANGE_PARAM

if (delta == 0)
INTERNAL_EXN("Delta can not be zero");

output_shape.dim(0) = ceil((limit - start) / delta);

return output_shape;
}

} // namespace sinf

} // namespace luci

0 comments on commit ce3d380

Please sign in to comment.