From 6a6308b948c91ccfbc63f3f83060d96abf432e29 Mon Sep 17 00:00:00 2001 From: mbencer Date: Thu, 13 Jun 2024 05:26:34 +0200 Subject: [PATCH] [tools/onert_train] Add trainable ops support (#13162) This commits adds new flags of onert_train - trainable_ops_idx. It allows to pass trainable ops indexes nnfw training. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer --- tests/tools/onert_train/README.md | 3 ++ tests/tools/onert_train/src/args.cc | 42 ++++++++++++++++++++++ tests/tools/onert_train/src/args.h | 3 ++ tests/tools/onert_train/src/onert_train.cc | 7 ++++ 4 files changed, 55 insertions(+) diff --git a/tests/tools/onert_train/README.md b/tests/tools/onert_train/README.md index 4a798f8f0e3..3f3142d1f2c 100644 --- a/tests/tools/onert_train/README.md +++ b/tests/tools/onert_train/README.md @@ -31,6 +31,7 @@ onert_train \ --learning_rate 0.01 \ --loss 2 \ # cateogrical crossentropy --loss_reduction_type 1 # sum over batch size +--trainable_ops_idx 0-30 # indexes of operations which should be trained ``` `onert_train --help` would help you to set each parameter. @@ -79,6 +80,7 @@ If you start with tensorflow code, you could first save it as saved format and t Now you're ready to run `onert_train`.
Please pass your model file to `--modelfile` and data files to `--load_input:raw` and `--load_expected:raw`.
Also, you could set training parameter using options like `--batch_size`, `--epoch`.. etc. +Please pay special attention for `trainable_ops_idx` to determine operations which should be trained. ```bash $ onert_train \ @@ -91,4 +93,5 @@ $ onert_train \ --learning_rate 0.001 \ --loss 2 \ # cateogrical crossentropy --loss_reduction_type 1 # sum over batch size +--trainable_ops_idx 0-10 ``` diff --git a/tests/tools/onert_train/src/args.cc b/tests/tools/onert_train/src/args.cc index 895060a9c3f..447efba60f9 100644 --- a/tests/tools/onert_train/src/args.cc +++ b/tests/tools/onert_train/src/args.cc @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -229,6 +230,44 @@ void Args::Initialize(void) } }; + auto process_trainable_ops = [&](const std::string &trainable_ops_idx_str) { + std::stringstream ss(trainable_ops_idx_str); + std::string token; + try + { + while (std::getline(ss, token, ',')) + { + const auto range_iter = token.find("-"); + if (range_iter != std::string::npos) + { + const auto begin_idx = std::stoi(token.substr(0, range_iter)); + const auto end_idx = std::stoi(token.substr(range_iter + 1, token.size())); + if (begin_idx > end_idx) + { + std::cerr << "begin_idx=" << begin_idx + << " of trainable operator index cannot be greater than end_idx=" << end_idx + << "\n"; + exit(1); + } + std::vector range(end_idx - begin_idx + 1); + std::iota(std::begin(range), std::end(range), begin_idx); + _trainable_ops_idx.insert(std::begin(range), std::end(range)); + } + _trainable_ops_idx.emplace(std::stoi(token)); + } + } + catch (const std::invalid_argument &ex) + { + std::cerr << "Invalid argument passed as trainable_ops_idx: " << ex.what() << "\n"; + exit(1); + } + catch (const std::out_of_range &ex) + { + std::cerr << "Out of range argument passed as trainable_ops_idx: " << ex.what() << "\n"; + exit(1); + } + }; + // General options po::options_description general("General options", 100); @@ -285,6 +324,9 @@ void Args::Initialize(void) "The output buffer size in JSON 1D array\n" "If not given, the model's output sizes are used\n" "e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80.") + ("trainable_ops_idx", po::value()->notifier(process_trainable_ops), + "Indexes of trainable nodes in the graph (indexes numeration starts with 0). " + "The indexes can be passed as a comma-separated list (like 65,68,70) or in a range form (like 60-70).") ; // clang-format on diff --git a/tests/tools/onert_train/src/args.h b/tests/tools/onert_train/src/args.h index 75557b106d2..de99720abee 100644 --- a/tests/tools/onert_train/src/args.h +++ b/tests/tools/onert_train/src/args.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "nnfw_experimental.h" @@ -69,6 +70,7 @@ class Args const bool printVersion(void) const { return _print_version; } const int getVerboseLevel(void) const { return _verbose_level; } std::unordered_map getOutputSizes(void) const { return _output_sizes; } + std::set getTrainableOpsIdx(void) const { return _trainable_ops_idx; } private: void Initialize(); @@ -115,6 +117,7 @@ class Args bool _print_version = false; int _verbose_level; std::unordered_map _output_sizes; + std::set _trainable_ops_idx; }; } // end of namespace onert_train diff --git a/tests/tools/onert_train/src/onert_train.cc b/tests/tools/onert_train/src/onert_train.cc index 66f0edd507e..b2dad98de34 100644 --- a/tests/tools/onert_train/src/onert_train.cc +++ b/tests/tools/onert_train/src/onert_train.cc @@ -145,6 +145,13 @@ int main(const int argc, char **argv) args.getLossReductionType().value_or(tri.loss_info.reduction_type); tri.opt = args.getOptimizerType().value_or(tri.opt); + size_t pos = 0; + tri.trainble_ops_size = args.getTrainableOpsIdx().size(); + for (auto const &idx : args.getTrainableOpsIdx()) + { + tri.trainble_ops_idx[pos++] = idx; + } + std::cout << "== training parameter ==" << std::endl; std::cout << tri; std::cout << "========================" << std::endl;