diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc index 615b1650c38..ea4e212b138 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.cc +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc @@ -103,6 +103,23 @@ void UseDefGenerator::visit(const train::operation::Loss &node) usedef_chain.removeTrainingUse(backwarding_op_index); } +void UseDefGenerator::visit(const train::operation::Reshape &node) +{ + assert(_node_to_idx.find(&node) != _node_to_idx.end()); + const auto &op_index = _node_to_idx.at(&node); + const auto backwarding_op_index = TrainingOperationIndex{op_index, false}; + + // Insert use of backwarding(backprop) output + const auto &out_index = node.getOutputs().at(0); + const auto incoming_index = TrainingOperandIndex{out_index, false}; + insertUse(incoming_index, backwarding_op_index); + + // Set def of backwarding(backprop) input + const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT); + const auto outgoing_index = TrainingOperandIndex{in_index, false}; + insertBackPropDef(outgoing_index, backwarding_op_index); +} + void UseDefGenerator::insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index) { diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.h b/runtime/onert/core/src/ir/train/UseDefGenerator.h index 369d9a22338..4fa89c0da9f 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.h +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.h @@ -65,6 +65,7 @@ class UseDefGenerator : public UseDefGeneratorBase public: void visit(const train::operation::Loss &node) override; + void visit(const train::operation::Reshape &node) override; private: void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);