From e30dc0a9334b5a9a0aa7564e1d8b07294b98f734 Mon Sep 17 00:00:00 2001 From: Zhengzhouheng Date: Fri, 22 Dec 2023 09:09:11 +0000 Subject: [PATCH] Support node trace db from tflite to tim vx Type: Code Improvement --- CMakeLists.txt | 5 +++ cmake/modules/Findtim-vx.cmake | 9 +++++ delegate_main.cc | 34 +++++++++++++++++ delegate_main.h | 11 ++++++ utils.cc | 69 +++++++++++++++++++++++++++++++++- utils.h | 8 ++++ 6 files changed, 135 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f06aa05..e4728f7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ cmake_minimum_required(VERSION 3.16) option(TFLITE_ENABLE_MULTI_DEVICE "Enable multi devices support" OFF) option(TFLITE_ENABLE_OPTIMIZE "Enable optimize tiny yolov4" OFF) +option(TFLITE_ENABLE_NODE_TRACE "Enable node trace" OFF) if(TFLITE_ENABLE_OPTIMIZE) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DVSI_FEAT_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS -DENABLE_TENSOR_CACHE") @@ -64,6 +65,10 @@ if(TFLITE_ENABLE_MULTI_DEVICE) ADD_DEFINITIONS(-DMULTI_DEVICE_FEATURE_MODE) endif() +if(TFLITE_ENABLE_NODE_TRACE) + ADD_DEFINITIONS(-DNODE_TRACE_DB_MODE) +endif() + add_library(vx_delegate SHARED ${VX_DELEGATES_SRCS}) list(APPEND VX_CUSTOM_OP_SRCS diff --git a/cmake/modules/Findtim-vx.cmake b/cmake/modules/Findtim-vx.cmake index 01efbaf..d0240e8 100644 --- a/cmake/modules/Findtim-vx.cmake +++ b/cmake/modules/Findtim-vx.cmake @@ -25,6 +25,9 @@ if(TFLITE_ENABLE_MULTI_DEVICE) set(TIM_VX_ENABLE_40BIT "ON") endif() +if(TFLITE_ENABLE_NODE_TRACE) + set(TIM_VX_ENABLE_NODE_TRACE "ON") +endif() if((NOT DEFINED TIM_VX_INSTALL)) if(TFLITE_ENABLE_MULTI_DEVICE AND (NOT EXTERNAL_VIV_SDK)) message(FATAL_ERROR "FATAL: multi device only suppot 40 bit driver, @@ -43,10 +46,16 @@ if((NOT DEFINED TIM_VX_INSTALL)) include_directories(${tim-vx_SOURCE_DIR}/include) add_subdirectory("${tim-vx_SOURCE_DIR}" "${tim-vx_BINARY_DIR}") + if(${TIM_VX_ENABLE_NODE_TRACE}) + list(APPEND VX_DELEGATE_DEPENDENCIES ${tim-vx_BINARY_DIR}/_deps/jsoncpp-build/src/lib_json/libjsoncpp.so) + endif() # list(APPEND VX_DELEGATE_DEPENDENCIES tim-vx) else() message("=== Building with TIM_VX_LIBRIRIES from ${TIM_VX_INSTALL} ===") include_directories(${TIM_VX_INSTALL}/include) set(LIBDIR lib) list(APPEND VX_DELEGATE_DEPENDENCIES ${TIM_VX_INSTALL}/${LIBDIR}/libtim-vx.so) + if(${TIM_VX_ENABLE_NODE_TRACE}) + list(APPEND VX_DELEGATE_DEPENDENCIES ${TIM_VX_INSTALL}/${LIBDIR}/libjsoncpp.so) + endif() endif() diff --git a/delegate_main.cc b/delegate_main.cc index f2e4980..33937c7 100644 --- a/delegate_main.cc +++ b/delegate_main.cc @@ -41,6 +41,10 @@ #include "tim/transform/layout_inference.h" #include "tim/transform/mean_stddev_normalize_fusion.h" +#ifdef NODE_TRACE_DB_MODE +#include "json/json.h" +#endif + using namespace tflite; namespace { @@ -529,6 +533,11 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data, TfLiteContext* context, TfLiteNode* node) { TFLITE_LOG(TFLITE_LOG_INFO, "Delegate::Invoke node: %p", node->user_data); + +#ifdef NODE_TRACE_DB_MODE + std::vector tflite_node_id_map; +#endif + if (!compiled_) { // TODO(bo): Handling multi-thread use case context_ = tim::vx::Context::Create(); @@ -562,12 +571,29 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data, auto& states = op_info.states; auto& builtin_data = op_info.builtin_data; +#ifdef NODE_TRACE_DB_MODE + vx::delegate::TfliteNodeIDPair tflite_node_id_pair; + std::vector> before_op_vector; + std::vector> after_op_vector; +#endif + std::vector inputs_outputs; std::copy( inputs.begin(), inputs.end(), std::back_inserter(inputs_outputs)); std::copy( outputs.begin(), outputs.end(), std::back_inserter(inputs_outputs)); +#ifdef NODE_TRACE_DB_MODE + tflite_node_id_pair.builtin_code = builtin_code; + std::copy( + inputs.begin(), inputs.end(), std::back_inserter(tflite_node_id_pair.inputs)); + std::copy( + outputs.begin(), outputs.end(), std::back_inserter(tflite_node_id_pair.outputs)); + std::copy( + this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(before_op_vector)); + tflite_node_id_map.push_back(tflite_node_id_pair); +#endif + for (size_t port_idx = 0; port_idx < inputs_outputs.size(); port_idx++) { int tensor_idx = inputs_outputs[port_idx]; if (-1 != tensor_idx && tensors_.find(tensor_idx) == tensors_.end()) { @@ -634,7 +660,15 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data, states_tensors, builtin_data.data()); } +#ifdef NODE_TRACE_DB_MODE + std::copy( + this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(after_op_vector)); + vx::delegate::utils::MapTfliteNodeToTimVxNode(before_op_vector, after_op_vector, tflite_node_id_map); +#endif } +#ifdef NODE_TRACE_DB_MODE + vx::delegate::utils::GenerateVxNodeTraceDb(tflite_node_id_map); +#endif TFLITE_LOG(TFLITE_LOG_INFO, "Verifying graph"); // Do normalization op fusion before layout inference diff --git a/delegate_main.h b/delegate_main.h index 1aec32a..9084c29 100644 --- a/delegate_main.h +++ b/delegate_main.h @@ -68,6 +68,17 @@ typedef struct { bool error_during_invoke; } VxDelegateOptions; +#ifdef NODE_TRACE_DB_MODE +typedef struct +{ + //tflite node unique id + std::vector inputs; + std::vector outputs; + int builtin_code; + //tim wx node uid + std::vector op_uids; +}TfliteNodeIDPair; +#endif class Delegate; diff --git a/utils.cc b/utils.cc index 7898951..df235a6 100644 --- a/utils.cc +++ b/utils.cc @@ -25,6 +25,10 @@ #include "utils.h" #include "tensorflow/lite/minimal_logging.h" +#ifdef NODE_TRACE_DB_MODE +#include "json/json.h" +#endif + using namespace tflite; namespace vx { @@ -66,7 +70,6 @@ std::vector GetOvxTransposePerm(const std::vector& perm) { return ovx_perm; } - void GenerateWeightsDataForBilinear(float* data, const std::vector& weight_shape, uint32_t scale_w, @@ -109,6 +112,70 @@ void GenerateWeightDataForNearest(float* data, return; } +#ifdef NODE_TRACE_DB_MODE +void MapTfliteNodeToTimVxNode( + const std::vector>& before_op_vector, + const std::vector>& after_op_vector, + std::vector& tflite_node_id_map) { + size_t new_operation_size = after_op_vector.size() - before_op_vector.size(); + size_t i = 0; + std::vector new_operation; + if (new_operation_size <= 0 || tflite_node_id_map.size() == 0) { + return; + } + + for (i = 0; i < new_operation_size; i++) { + size_t new_operation_index = before_op_vector.size(); + uint32_t uid = after_op_vector[new_operation_index + i]->uid(); + tflite_node_id_map[tflite_node_id_map.size() - 1].op_uids.push_back(uid); + } + return; +} + +void GenerateVxNodeTraceDb( + std::vector& tflite_node_id_map) { + Json::Value root; + + Json::StyledWriter sw; + uint32_t i = 0; + std::fstream fs; + fs.open("vx_node_trace_db.json", std::ios::out | std::ios::trunc); + + for (auto tflite_node_id_pair : tflite_node_id_map) { + Json::Value tflite_node_uid; + Json::Value tim_vx_uids; + + Json::Value inputs_ids; + Json::Value outputs_ids; + Json::Value tflite_node_builtin_code; + + Json::Value map_pair; + for (i = 0; i < tflite_node_id_pair.inputs.size(); i++) { + inputs_ids[i] = tflite_node_id_pair.inputs[i]; + } + for (i = 0; i < tflite_node_id_pair.outputs.size(); i++) { + outputs_ids[i] = tflite_node_id_pair.outputs[i]; + } + tflite_node_builtin_code = tflite_node_id_pair.builtin_code; + tflite_node_uid["inputs"] = inputs_ids; + tflite_node_uid["outputs"] = outputs_ids; + tflite_node_uid["builtin_code"] = tflite_node_id_pair.builtin_code; + + for (i = 0; i < tflite_node_id_pair.op_uids.size(); i++) { + tim_vx_uids[i] = tflite_node_id_pair.op_uids[i]; + } + + map_pair["tflite_node_id"] = tflite_node_uid; + map_pair["tim_vx_uid"] = tim_vx_uids; + root.append(map_pair); + } + + fs << sw.write(root); + fs.close(); + return; +} +#endif + } // namespace utils } // namespace delegate } // namespace vx diff --git a/utils.h b/utils.h index 2a6d5a9..53d9cd7 100644 --- a/utils.h +++ b/utils.h @@ -79,6 +79,14 @@ void GenerateWeightsDataForBilinear(float* data, void GenerateWeightDataForNearest(float* data, const std::vector& weight_shape); +#ifdef NODE_TRACE_DB_MODE +void MapTfliteNodeToTimVxNode(const std::vector>& before_op_vector, + const std::vector>& after_op_vector, + std::vector& tflite_node_id_map); + +void GenerateVxNodeTraceDb(std::vector& tflite_node_id_map); +#endif + template inline void Quantize(const std::vector& data, float scale, int32_t zero_point, std::vector& quant_data) {