Skip to content

Commit

Permalink
[onert/training] Introduce TensorPlanner (Samsung#13433)
Browse files Browse the repository at this point in the history
This commit introduces TensorPlanner that plans all tensors associated with training.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jul 18, 2024
1 parent cf27d66 commit 8e19a4a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
65 changes: 65 additions & 0 deletions runtime/onert/backend/train/TensorPlanner.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "TensorPlanner.h"

#include <util/logging.h>

namespace onert
{
namespace backend
{
namespace train
{

TensorPlanner::TensorPlanner(const ir::train::TrainableGraph &tgraph,
const util::Set<ir::OperandIndex> &external_operands)
: _tgraph{tgraph}, _external_operands{external_operands}
{
// DO NOTHING
// TODO Remove the following lines
UNUSED_RELEASE(_tgraph);
UNUSED_RELEASE(_external_operands);
}

void TensorPlanner::planNonConstTensors(TensorBuilder *)
{
// TODO Plan non-const tensors
}

void TensorPlanner::planTrainableTensors(TensorBuilder *)
{
// TODO Plan trainable tensors such as weights
}

void TensorPlanner::planBackPropTensors(TensorBuilder *)
{
// TODO Plan back-propagated tensors
}

void TensorPlanner::planGradientTensors(TensorBuilder *)
{
// TODO Plan gradient tensors
}

void TensorPlanner::planDisposableBackPropTensors(TensorBuilder *)
{
// TODO Plan diposable backprop tensors
}

} // namespace train
} // namespace backend
} // namespace onert
58 changes: 58 additions & 0 deletions runtime/onert/backend/train/TensorPlanner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_BACKEND_TRAIN_TENSOR_PLANNER_H__
#define __ONERT_BACKEND_TRAIN_TENSOR_PLANNER_H__

#include "TensorBuilder.h"

#include <ir/train/TrainableGraph.h>
#include <util/Set.h>

namespace onert
{
namespace backend
{
namespace train
{

class TensorPlanner
{
public:
TensorPlanner(const ir::train::TrainableGraph &tgraph,
const util::Set<ir::OperandIndex> &external_operands);
TensorPlanner(const TensorPlanner &) = delete;
TensorPlanner(TensorPlanner &&) = delete;
TensorPlanner &operator=(const TensorPlanner &) = delete;
TensorPlanner &operator=(TensorPlanner &&) = delete;
~TensorPlanner() = default;

void planNonConstTensors(TensorBuilder *tensor_builder);
void planTrainableTensors(TensorBuilder *tensor_builder);
void planBackPropTensors(TensorBuilder *tensor_builder);
void planGradientTensors(TensorBuilder *tensor_builder);
void planDisposableBackPropTensors(TensorBuilder *tensor_builder);

private:
const ir::train::TrainableGraph &_tgraph;
const util::Set<ir::OperandIndex> &_external_operands;
};

} // namespace train
} // namespace backend
} // namespace onert

#endif // __ONERT_BACKEND_TRAIN_TENSOR_PLANNER_H__

0 comments on commit 8e19a4a

Please sign in to comment.