From 5cae6ffbfa4a3a93042f829dc95b9cda2afaff22 Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Thu, 18 Jul 2024 14:59:42 +0900 Subject: [PATCH] [onert] Add generating training usedefs for Reshape op (#13460) This commit adds generating training usedefs for Reshape operation. ONE-DCO-1.0-Signed-off-by: ragmani --- .../onert/core/src/ir/train/UseDefGenerator.cc | 17 +++++++++++++++++ .../onert/core/src/ir/train/UseDefGenerator.h | 1 + 2 files changed, 18 insertions(+) diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc index 615b1650c38..ea4e212b138 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.cc +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc @@ -103,6 +103,23 @@ void UseDefGenerator::visit(const train::operation::Loss &node) usedef_chain.removeTrainingUse(backwarding_op_index); } +void UseDefGenerator::visit(const train::operation::Reshape &node) +{ + assert(_node_to_idx.find(&node) != _node_to_idx.end()); + const auto &op_index = _node_to_idx.at(&node); + const auto backwarding_op_index = TrainingOperationIndex{op_index, false}; + + // Insert use of backwarding(backprop) output + const auto &out_index = node.getOutputs().at(0); + const auto incoming_index = TrainingOperandIndex{out_index, false}; + insertUse(incoming_index, backwarding_op_index); + + // Set def of backwarding(backprop) input + const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT); + const auto outgoing_index = TrainingOperandIndex{in_index, false}; + insertBackPropDef(outgoing_index, backwarding_op_index); +} + void UseDefGenerator::insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index) { diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.h b/runtime/onert/core/src/ir/train/UseDefGenerator.h index 369d9a22338..4fa89c0da9f 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.h +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.h @@ -65,6 +65,7 @@ class UseDefGenerator : public UseDefGeneratorBase public: void visit(const train::operation::Loss &node) override; + void visit(const train::operation::Reshape &node) override; private: void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);