diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc index ea4e212b138..b2ecfad5911 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.cc +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc @@ -71,6 +71,44 @@ UseDefChains UseDefGenerator::operator()() return _training_usedefs; } +void UseDefGenerator::visit(const train::operation::Conv2D &node) +{ + assert(_node_to_idx.find(&node) != _node_to_idx.end()); + const auto &op_index = _node_to_idx.at(&node); + const auto backwarding_op_index = TrainingOperationIndex{op_index, false}; + + // Insert use of forwarding inputs + const auto &in_index = node.getInputs().at(train::operation::Conv2D::Input::INPUT); + const auto in_forwarding_index = TrainingOperandIndex{in_index, true}; + insertUse(in_forwarding_index, backwarding_op_index); + + const auto &weights_index = node.getInputs().at(train::operation::Conv2D::Input::KERNEL); + const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true}; + insertUse(weights_forwarding_index, backwarding_op_index); + + // Insert use of forwarding output + if (node.param().activation != ir::Activation::NONE) + { + const auto &out_index = node.getOutputs().at(0); + const auto out_forwarding_index = TrainingOperandIndex{out_index, true}; + insertUse(out_forwarding_index, backwarding_op_index); + } + + // Set def of backwarding inputs + const auto outgoing_index = TrainingOperandIndex{in_index, false}; + insertBackPropDef(outgoing_index, backwarding_op_index); + + const auto weights_gradient_index = TrainingOperandIndex{weights_index, false}; + insertDef(weights_gradient_index, backwarding_op_index); + + const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS); + if (bias_index.valid()) + { + const auto bias_gradient_index = TrainingOperandIndex{bias_index, false}; + insertDef(bias_gradient_index, backwarding_op_index); + } +} + void UseDefGenerator::visit(const train::operation::Loss &node) { assert(_node_to_idx.find(&node) != _node_to_idx.end()); diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.h b/runtime/onert/core/src/ir/train/UseDefGenerator.h index 4fa89c0da9f..b889a2f5554 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.h +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.h @@ -64,6 +64,7 @@ class UseDefGenerator : public UseDefGeneratorBase UseDefChains operator()(); public: + void visit(const train::operation::Conv2D &node) override; void visit(const train::operation::Loss &node) override; void visit(const train::operation::Reshape &node) override;