Skip to content

Commit

Permalink
[onert] Add generating training usedefs for Conv2D op (Samsung#13453)
Browse files Browse the repository at this point in the history
This commit adds generating training usedefs for Conv2D operation.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jul 18, 2024
1 parent 5cae6ff commit 19c259d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
38 changes: 38 additions & 0 deletions runtime/onert/core/src/ir/train/UseDefGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,44 @@ UseDefChains UseDefGenerator::operator()()
return _training_usedefs;
}

void UseDefGenerator::visit(const train::operation::Conv2D &node)
{
assert(_node_to_idx.find(&node) != _node_to_idx.end());
const auto &op_index = _node_to_idx.at(&node);
const auto backwarding_op_index = TrainingOperationIndex{op_index, false};

// Insert use of forwarding inputs
const auto &in_index = node.getInputs().at(train::operation::Conv2D::Input::INPUT);
const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
insertUse(in_forwarding_index, backwarding_op_index);

const auto &weights_index = node.getInputs().at(train::operation::Conv2D::Input::KERNEL);
const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
insertUse(weights_forwarding_index, backwarding_op_index);

// Insert use of forwarding output
if (node.param().activation != ir::Activation::NONE)
{
const auto &out_index = node.getOutputs().at(0);
const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
insertUse(out_forwarding_index, backwarding_op_index);
}

// Set def of backwarding inputs
const auto outgoing_index = TrainingOperandIndex{in_index, false};
insertBackPropDef(outgoing_index, backwarding_op_index);

const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
insertDef(weights_gradient_index, backwarding_op_index);

const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS);
if (bias_index.valid())
{
const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
insertDef(bias_gradient_index, backwarding_op_index);
}
}

void UseDefGenerator::visit(const train::operation::Loss &node)
{
assert(_node_to_idx.find(&node) != _node_to_idx.end());
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/src/ir/train/UseDefGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class UseDefGenerator : public UseDefGeneratorBase
UseDefChains operator()();

public:
void visit(const train::operation::Conv2D &node) override;
void visit(const train::operation::Loss &node) override;
void visit(const train::operation::Reshape &node) override;

Expand Down

0 comments on commit 19c259d

Please sign in to comment.