Skip to content

Commit

Permalink
[tools/onert_train] Add trainable ops support (Samsung#13162)
Browse files Browse the repository at this point in the history
This commits adds new flags of onert_train - trainable_ops_idx.
It allows to pass trainable ops indexes nnfw training.

ONE-DCO-1.0-Signed-off-by: Mateusz Bencer <[email protected]>
  • Loading branch information
mbencer authored Jun 13, 2024
1 parent b11d6be commit 6a6308b
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/tools/onert_train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ onert_train \
--learning_rate 0.01 \
--loss 2 \ # cateogrical crossentropy
--loss_reduction_type 1 # sum over batch size
--trainable_ops_idx 0-30 # indexes of operations which should be trained
```

`onert_train --help` would help you to set each parameter.
Expand Down Expand Up @@ -79,6 +80,7 @@ If you start with tensorflow code, you could first save it as saved format and t
Now you're ready to run `onert_train`. <br/>
Please pass your model file to `--modelfile` and data files to `--load_input:raw` and `--load_expected:raw`. <br/>
Also, you could set training parameter using options like `--batch_size`, `--epoch`.. etc.
Please pay special attention for `trainable_ops_idx` to determine operations which should be trained.

```bash
$ onert_train \
Expand All @@ -91,4 +93,5 @@ $ onert_train \
--learning_rate 0.001 \
--loss 2 \ # cateogrical crossentropy
--loss_reduction_type 1 # sum over batch size
--trainable_ops_idx 0-10
```
42 changes: 42 additions & 0 deletions tests/tools/onert_train/src/args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <functional>
#include <iostream>
#include <numeric>
#include <sys/stat.h>
#include <json/json.h>

Expand Down Expand Up @@ -229,6 +230,44 @@ void Args::Initialize(void)
}
};

auto process_trainable_ops = [&](const std::string &trainable_ops_idx_str) {
std::stringstream ss(trainable_ops_idx_str);
std::string token;
try
{
while (std::getline(ss, token, ','))
{
const auto range_iter = token.find("-");
if (range_iter != std::string::npos)
{
const auto begin_idx = std::stoi(token.substr(0, range_iter));
const auto end_idx = std::stoi(token.substr(range_iter + 1, token.size()));
if (begin_idx > end_idx)
{
std::cerr << "begin_idx=" << begin_idx
<< " of trainable operator index cannot be greater than end_idx=" << end_idx
<< "\n";
exit(1);
}
std::vector<uint32_t> range(end_idx - begin_idx + 1);
std::iota(std::begin(range), std::end(range), begin_idx);
_trainable_ops_idx.insert(std::begin(range), std::end(range));
}
_trainable_ops_idx.emplace(std::stoi(token));
}
}
catch (const std::invalid_argument &ex)
{
std::cerr << "Invalid argument passed as trainable_ops_idx: " << ex.what() << "\n";
exit(1);
}
catch (const std::out_of_range &ex)
{
std::cerr << "Out of range argument passed as trainable_ops_idx: " << ex.what() << "\n";
exit(1);
}
};

// General options
po::options_description general("General options", 100);

Expand Down Expand Up @@ -285,6 +324,9 @@ void Args::Initialize(void)
"The output buffer size in JSON 1D array\n"
"If not given, the model's output sizes are used\n"
"e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80.")
("trainable_ops_idx", po::value<std::string>()->notifier(process_trainable_ops),
"Indexes of trainable nodes in the graph (indexes numeration starts with 0). "
"The indexes can be passed as a comma-separated list (like 65,68,70) or in a range form (like 60-70).")
;
// clang-format on

Expand Down
3 changes: 3 additions & 0 deletions tests/tools/onert_train/src/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <set>
#include <boost/program_options.hpp>

#include "nnfw_experimental.h"
Expand Down Expand Up @@ -69,6 +70,7 @@ class Args
const bool printVersion(void) const { return _print_version; }
const int getVerboseLevel(void) const { return _verbose_level; }
std::unordered_map<uint32_t, uint32_t> getOutputSizes(void) const { return _output_sizes; }
std::set<uint32_t> getTrainableOpsIdx(void) const { return _trainable_ops_idx; }

private:
void Initialize();
Expand Down Expand Up @@ -115,6 +117,7 @@ class Args
bool _print_version = false;
int _verbose_level;
std::unordered_map<uint32_t, uint32_t> _output_sizes;
std::set<uint32_t> _trainable_ops_idx;
};

} // end of namespace onert_train
Expand Down
7 changes: 7 additions & 0 deletions tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ int main(const int argc, char **argv)
args.getLossReductionType().value_or(tri.loss_info.reduction_type);
tri.opt = args.getOptimizerType().value_or(tri.opt);

size_t pos = 0;
tri.trainble_ops_size = args.getTrainableOpsIdx().size();
for (auto const &idx : args.getTrainableOpsIdx())
{
tri.trainble_ops_idx[pos++] = idx;
}

std::cout << "== training parameter ==" << std::endl;
std::cout << tri;
std::cout << "========================" << std::endl;
Expand Down

0 comments on commit 6a6308b

Please sign in to comment.