From d619e4995444111332c57e0195b3def8981ed6e7 Mon Sep 17 00:00:00 2001 From: Hyeongseok Oh Date: Fri, 12 Apr 2024 11:04:29 +0900 Subject: [PATCH] [onert] Introduce Pool2D dynamic shape inference (#12861) This commit adds Pool2D dynamic shape inference. ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh --- .../core/include/exec/DynamicShapeInferer.h | 1 + .../core/src/exec/DynamicShapeInferer.cc | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/runtime/onert/core/include/exec/DynamicShapeInferer.h b/runtime/onert/core/include/exec/DynamicShapeInferer.h index f814b789a90..d34dc053b58 100644 --- a/runtime/onert/core/include/exec/DynamicShapeInferer.h +++ b/runtime/onert/core/include/exec/DynamicShapeInferer.h @@ -72,6 +72,7 @@ class DynamicShapeInferer : public ir::OperationVisitor void visit(const ir::operation::Pack &op) override; void visit(const ir::operation::Pad &op) override; void visit(const ir::operation::Permute &op) override; + void visit(const ir::operation::Pool2D &op) override; void visit(const ir::operation::Pow &op) override; // TODO write op starting from Q void visit(const ir::operation::Range &op) override; diff --git a/runtime/onert/core/src/exec/DynamicShapeInferer.cc b/runtime/onert/core/src/exec/DynamicShapeInferer.cc index bbd8b25506b..691a1193342 100644 --- a/runtime/onert/core/src/exec/DynamicShapeInferer.cc +++ b/runtime/onert/core/src/exec/DynamicShapeInferer.cc @@ -705,6 +705,26 @@ void DynamicShapeInferer::visit(const ir::operation::Permute & /* op */) // on-the-fly, as it must support inter-backend inference/allocation. } +void DynamicShapeInferer::visit(const ir::operation::Pool2D &op) +{ + // check if input is not dynamic + auto input_ind = op.getInputs().at(ir::operation::Pool2D::INPUT); + auto input = _tensor_registry->getITensor(input_ind); + + if (!input->is_dynamic()) + return; + + ir::Shape input_shape = input->getShape(); + + auto output_ind = op.getOutputs().at(0); + auto output = _tensor_registry->getITensor(output_ind); + + ir::Shape output_shape = shape_inference::inferPoolShape(input_shape, op.param()); + + output->applyShape(output_shape); + assert(output->buffer() != nullptr); +} + void DynamicShapeInferer::visit(const ir::operation::Pow &op) { handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),