forked from Samsung/ONE
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CFE/minmax-embedder] Introduce minmax-embedder library (Samsung#11722)
* [CFE/minmax-embedder] Introduce minmax-embedder library It introduces minmax-embedder, which embedds minmax recording in hdf5 to circle. ONE-DCO-1.0-Signed-off-by: Sanggyu Lee <[email protected]>
- Loading branch information
1 parent
b5cef18
commit 4990a29
Showing
8 changed files
with
443 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
nnas_find_package(HDF5 COMPONENTS STATIC QUIET) | ||
|
||
if(NOT HDF5_FOUND) | ||
message(STATUS "Build minmax_embedder: FAILED (missing HDF5)") | ||
return() | ||
endif(NOT HDF5_FOUND) | ||
|
||
file(GLOB_RECURSE SOURCES "src/*.cpp") | ||
file(GLOB_RECURSE TESTS "src/*.test.cpp") | ||
list(REMOVE_ITEM SOURCES ${TESTS}) | ||
|
||
# | ||
# Library | ||
# | ||
add_library(minmax_embedder "${SOURCES}") | ||
target_include_directories(minmax_embedder PUBLIC ${HDF5_INCLUDE_DIRS}) | ||
target_include_directories(minmax_embedder PRIVATE include) | ||
|
||
target_link_libraries(minmax_embedder ${HDF5_CXX_LIBRARIES}) | ||
target_link_libraries(minmax_embedder loco) | ||
target_link_libraries(minmax_embedder luci_import) | ||
target_link_libraries(minmax_embedder luci_service) | ||
target_link_libraries(minmax_embedder luci_pass) | ||
target_link_libraries(minmax_embedder luci_export) | ||
target_link_libraries(minmax_embedder luci_env) | ||
|
||
install(TARGETS minmax_embedder DESTINATION lib) | ||
install(DIRECTORY include/ DESTINATION include | ||
FILES_MATCHING PATTERN "*.h") | ||
# | ||
# GTest | ||
# | ||
if(NOT ENABLE_TEST) | ||
return() | ||
endif(NOT ENABLE_TEST) | ||
|
||
nnas_find_package(GTest REQUIRED) | ||
|
||
GTest_AddTest(minmax_embedder_test ${TESTS}) | ||
target_include_directories(minmax_embedder_test PRIVATE include) | ||
target_link_libraries(minmax_embedder_test minmax_embedder) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# minmax-embedder | ||
|
||
_minmax-embedder_ embeds minmax to circle. |
39 changes: 39 additions & 0 deletions
39
compiler/minmax-embedder/include/minmax-embedder/Embedder.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __MINMAX_EMBEDDER_EMBEDDER_H__ | ||
#define __MINMAX_EMBEDDER_EMBEDDER_H__ | ||
|
||
#include <string> | ||
|
||
namespace minmax_embedder | ||
{ | ||
|
||
struct EmbedderOptions | ||
{ | ||
float min_percentile = 0.0f; // dummy initial value to make SE tool happy | ||
float max_percentile = 0.0f; // dummy initial value To make SE tool happy | ||
}; | ||
|
||
class Embedder | ||
{ | ||
public: | ||
void embed(const std::string &output_path, const std::string &input_path, | ||
const std::string &minmax_path, const EmbedderOptions &); | ||
}; | ||
} // namespace minmax_embedder | ||
|
||
#endif // __MINMAX_EMBEDDER_EMBEDDER_H__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
require("loco") | ||
require("luci") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "minmax-embedder/Embedder.h" | ||
|
||
#include <luci/CircleExporter.h> | ||
#include <luci/CircleFileExpContract.h> | ||
#include <luci/ImporterEx.h> | ||
#include <luci/IR/CircleNode.h> | ||
#include <luci/IR/CircleQuantParam.h> | ||
#include <luci/Profile/CircleNodeID.h> | ||
#include <luci/Service/Validate.h> | ||
|
||
#include "h5/Reader.h" | ||
|
||
#include <cassert> | ||
#include <cmath> // for std::floor | ||
#include <iostream> | ||
#include <string> | ||
|
||
namespace | ||
{ | ||
|
||
/* NOTE: getNthPercentile is copied from compiler/record-minmax/include/RecordFunction.h */ | ||
/** | ||
* @brief getNthPercentile calculates the n-th percentile of input vector (0.0 <= n <= 100.0) | ||
* linear interpolation is used when the desired percentile lies between two data points | ||
*/ | ||
float getNthPercentile(std::vector<float> &vector, float percentile) | ||
{ | ||
if (percentile < 0 || percentile > 100) | ||
throw std::runtime_error("Percentile must be ranged from 0 to 100"); | ||
|
||
if (vector.empty()) | ||
throw std::runtime_error("Percentile must take a non-empty vector as an argument"); | ||
|
||
if (vector.size() == 1) | ||
return vector[0]; | ||
|
||
std::vector<float> copy; | ||
copy.assign(vector.begin(), vector.end()); | ||
std::sort(copy.begin(), copy.end()); | ||
|
||
if (percentile == 0.0) | ||
return copy.front(); | ||
|
||
if (percentile == 100.0) | ||
return copy.back(); | ||
|
||
int index = static_cast<int>(std::floor((copy.size() - 1) * percentile / 100.0)); | ||
|
||
float percent_i = static_cast<float>(index) / static_cast<float>(copy.size() - 1); | ||
float fraction = | ||
(percentile / 100.0 - percent_i) / ((index + 1.0) / (copy.size() - 1.0) - percent_i); | ||
float res = copy[index] + fraction * (copy[index + 1] - copy[index]); | ||
return res; | ||
} | ||
|
||
} // namespace | ||
|
||
namespace minmax_embedder | ||
{ | ||
|
||
void Embedder::embed(const std::string &output_path, const std::string &input_path, | ||
const std::string &minmax_path, const EmbedderOptions &opt) | ||
{ | ||
// Load model from the file | ||
luci::ImporterEx importerex; | ||
auto module = importerex.importVerifyModule(input_path); | ||
if (module.get() == nullptr) | ||
throw std::runtime_error{"Input circle is invalid"}; | ||
|
||
h5::Reader mmr{minmax_path}; | ||
|
||
for (size_t idx = 0; idx < module->size(); ++idx) | ||
{ | ||
auto graph = module->graph(idx); | ||
|
||
/* read subgraph inputs */ | ||
const auto input_nodes = loco::input_nodes(graph); | ||
const auto n_inputs = input_nodes.size(); | ||
for (size_t input_idx = 0; input_idx < n_inputs; ++input_idx) | ||
{ | ||
const auto *circle_input = loco::must_cast<const luci::CircleInput *>(input_nodes[input_idx]); | ||
if (circle_input->index() != input_idx) | ||
throw std::runtime_error("Input order in minmax recording does not match to circle"); | ||
|
||
auto minmax = mmr.read_input(0, idx, input_idx); | ||
auto min = getNthPercentile(minmax.min_vector, opt.min_percentile); | ||
auto max = getNthPercentile(minmax.max_vector, opt.max_percentile); | ||
auto quantparam = std::make_unique<luci::CircleQuantParam>(); | ||
quantparam->min.push_back(min); | ||
quantparam->max.push_back(max); | ||
const auto *circle_node = loco::must_cast<const luci::CircleNode *>(input_nodes[input_idx]); | ||
auto mutable_node = const_cast<luci::CircleNode *>(circle_node); | ||
mutable_node->quantparam(std::move(quantparam)); | ||
} | ||
|
||
/* read op outputs */ | ||
uint32_t n_nodes = graph->nodes()->size(); | ||
for (uint32_t i = 0; i < n_nodes; ++i) | ||
{ | ||
auto node = loco::must_cast<luci::CircleNode *>(graph->nodes()->at(i)); | ||
if (not luci::has_node_id(node)) // Skip non-op nodes (e.g. input/const/output) | ||
continue; | ||
auto op_idx = luci::get_node_id(node); | ||
auto minmax = mmr.read(0, idx, op_idx); | ||
auto min = getNthPercentile(minmax.min_vector, opt.min_percentile); | ||
auto max = getNthPercentile(minmax.max_vector, opt.max_percentile); | ||
auto quantparam = std::make_unique<luci::CircleQuantParam>(); | ||
quantparam->min.push_back(min); | ||
quantparam->max.push_back(max); | ||
auto mutable_node = const_cast<luci::CircleNode *>(node); | ||
mutable_node->quantparam(std::move(quantparam)); | ||
} | ||
|
||
if (!luci::validate(graph)) | ||
throw std::runtime_error{"Circle after embedding minmax is invalid"}; | ||
} | ||
|
||
// Export to output Circle file | ||
luci::CircleExporter exporter; | ||
|
||
luci::CircleFileExpContract contract(module.get(), output_path); | ||
|
||
if (!exporter.invoke(&contract)) | ||
throw std::runtime_error{"Failed to export circle"}; | ||
} | ||
|
||
} // namespace minmax_embedder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "minmax-embedder/Embedder.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
using namespace minmax_embedder; | ||
|
||
namespace | ||
{ | ||
struct MinMaxEmbedderTest : public ::testing::Test | ||
{ | ||
EmbedderOptions _opt{0, 100}; | ||
}; | ||
|
||
} // namespace | ||
|
||
TEST_F(MinMaxEmbedderTest, invalid_input_NEG) | ||
{ | ||
Embedder embedder; | ||
EXPECT_THROW(embedder.embed("", "not_existing", "", _opt), std::runtime_error); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "Reader.h" | ||
|
||
#include <cstdio> | ||
#include <stdexcept> | ||
|
||
namespace | ||
{ | ||
bool exists(hid_t id, const char *path) { return H5Lexists(id, path, H5P_DEFAULT) > 0; } | ||
} // namespace | ||
|
||
namespace minmax_embedder | ||
{ | ||
namespace h5 | ||
{ | ||
static const char *h5_value_grpname = "value"; | ||
|
||
Reader::Reader(const std::string &filepath) : _file(filepath, H5F_ACC_RDONLY) | ||
{ | ||
_val_grp = _file.openGroup(h5_value_grpname); | ||
} | ||
|
||
// TODO: Handle multiple output | ||
MinMaxVectors Reader::read(int model_idx, int subg_idx, int op_idx) const | ||
{ | ||
MinMaxVectors mmv; | ||
float minmax[2]; | ||
auto num_run = _val_grp.getNumObjs(); | ||
for (uint32_t r = 0; r < num_run; ++r) | ||
{ | ||
// check whether minmax exists | ||
char path[128]; // 128 is enough to print "/value/run_%d/model_%d/subg_%d/op_%d" + null | ||
snprintf(path, 128, "/value/run_%d/model_%d/subg_%d/op_%d", r, model_idx, subg_idx, op_idx); | ||
if (!exists(_file.getId(), path)) | ||
continue; | ||
auto run_grp = _val_grp.openGroup(std::string("run_") + std::to_string(r)); | ||
auto model_grp = run_grp.openGroup(std::string("model_") + std::to_string(model_idx)); | ||
auto subg_grp = model_grp.openGroup(std::string("subg_") + std::to_string(subg_idx)); | ||
auto op_dset = subg_grp.openDataSet(std::string("op_") + std::to_string(op_idx)); | ||
H5::DataType dtype = op_dset.getDataType(); | ||
if (not(dtype == H5::PredType::IEEE_F32BE || dtype == H5::PredType::IEEE_F32LE)) | ||
throw std::runtime_error{"dtype of min, max in h5 is not float."}; | ||
op_dset.read(minmax, H5::PredType::NATIVE_FLOAT); | ||
mmv.min_vector.emplace_back(minmax[0]); | ||
mmv.max_vector.emplace_back(minmax[1]); | ||
} | ||
return mmv; | ||
} | ||
|
||
MinMaxVectors Reader::read_input(int model_idx, int subg_idx, int input_idx) const | ||
{ | ||
MinMaxVectors mmv; | ||
float minmax[2]; | ||
auto num_run = _val_grp.getNumObjs(); | ||
for (uint32_t r = 0; r < num_run; ++r) | ||
{ | ||
// check whether minmax exists | ||
char path[128]; // 128 is enough to print "/value/run_%d/model_%d/subg_%d/input_%d" + null | ||
snprintf(path, 128, "/value/run_%d/model_%d/subg_%d/input_%d", r, model_idx, subg_idx, | ||
input_idx); | ||
if (!exists(_file.getId(), path)) | ||
continue; | ||
auto run_grp = _val_grp.openGroup(std::string("run_") + std::to_string(r)); | ||
auto model_grp = run_grp.openGroup(std::string("model_") + std::to_string(model_idx)); | ||
auto subg_grp = model_grp.openGroup(std::string("subg_") + std::to_string(subg_idx)); | ||
auto op_dset = subg_grp.openDataSet(std::string("input_") + std::to_string(input_idx)); | ||
|
||
H5::DataType dtype = op_dset.getDataType(); | ||
if (not(dtype == H5::PredType::IEEE_F32BE || dtype == H5::PredType::IEEE_F32LE)) | ||
throw std::runtime_error{"dtype of min, max in h5 is not float."}; | ||
op_dset.read(minmax, H5::PredType::NATIVE_FLOAT); | ||
mmv.min_vector.emplace_back(minmax[0]); | ||
mmv.max_vector.emplace_back(minmax[1]); | ||
} | ||
return mmv; | ||
} | ||
|
||
} // namespace h5 | ||
} // namespace minmax_embedder |
Oops, something went wrong.