Skip to content

Commit

Permalink
Support node trace db from tflite to tim vx
Browse files Browse the repository at this point in the history
Type: Code Improvement
  • Loading branch information
zhengzhouheng authored and sunshinemyson committed Dec 26, 2023
1 parent 91471fc commit e30dc0a
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cmake_minimum_required(VERSION 3.16)

option(TFLITE_ENABLE_MULTI_DEVICE "Enable multi devices support" OFF)
option(TFLITE_ENABLE_OPTIMIZE "Enable optimize tiny yolov4" OFF)
option(TFLITE_ENABLE_NODE_TRACE "Enable node trace" OFF)

if(TFLITE_ENABLE_OPTIMIZE)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DVSI_FEAT_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS -DENABLE_TENSOR_CACHE")
Expand Down Expand Up @@ -64,6 +65,10 @@ if(TFLITE_ENABLE_MULTI_DEVICE)
ADD_DEFINITIONS(-DMULTI_DEVICE_FEATURE_MODE)
endif()

if(TFLITE_ENABLE_NODE_TRACE)
ADD_DEFINITIONS(-DNODE_TRACE_DB_MODE)
endif()

add_library(vx_delegate SHARED ${VX_DELEGATES_SRCS})

list(APPEND VX_CUSTOM_OP_SRCS
Expand Down
9 changes: 9 additions & 0 deletions cmake/modules/Findtim-vx.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ if(TFLITE_ENABLE_MULTI_DEVICE)
set(TIM_VX_ENABLE_40BIT "ON")
endif()

if(TFLITE_ENABLE_NODE_TRACE)
set(TIM_VX_ENABLE_NODE_TRACE "ON")
endif()
if((NOT DEFINED TIM_VX_INSTALL))
if(TFLITE_ENABLE_MULTI_DEVICE AND (NOT EXTERNAL_VIV_SDK))
message(FATAL_ERROR "FATAL: multi device only suppot 40 bit driver,
Expand All @@ -43,10 +46,16 @@ if((NOT DEFINED TIM_VX_INSTALL))
include_directories(${tim-vx_SOURCE_DIR}/include)
add_subdirectory("${tim-vx_SOURCE_DIR}"
"${tim-vx_BINARY_DIR}")
if(${TIM_VX_ENABLE_NODE_TRACE})
list(APPEND VX_DELEGATE_DEPENDENCIES ${tim-vx_BINARY_DIR}/_deps/jsoncpp-build/src/lib_json/libjsoncpp.so)
endif()
# list(APPEND VX_DELEGATE_DEPENDENCIES tim-vx)
else()
message("=== Building with TIM_VX_LIBRIRIES from ${TIM_VX_INSTALL} ===")
include_directories(${TIM_VX_INSTALL}/include)
set(LIBDIR lib)
list(APPEND VX_DELEGATE_DEPENDENCIES ${TIM_VX_INSTALL}/${LIBDIR}/libtim-vx.so)
if(${TIM_VX_ENABLE_NODE_TRACE})
list(APPEND VX_DELEGATE_DEPENDENCIES ${TIM_VX_INSTALL}/${LIBDIR}/libjsoncpp.so)
endif()
endif()
34 changes: 34 additions & 0 deletions delegate_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
#include "tim/transform/layout_inference.h"
#include "tim/transform/mean_stddev_normalize_fusion.h"

#ifdef NODE_TRACE_DB_MODE
#include "json/json.h"
#endif

using namespace tflite;
namespace {

Expand Down Expand Up @@ -529,6 +533,11 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
TfLiteContext* context,
TfLiteNode* node) {
TFLITE_LOG(TFLITE_LOG_INFO, "Delegate::Invoke node: %p", node->user_data);

#ifdef NODE_TRACE_DB_MODE
std::vector<vx::delegate::TfliteNodeIDPair> tflite_node_id_map;
#endif

if (!compiled_) {
// TODO(bo): Handling multi-thread use case
context_ = tim::vx::Context::Create();
Expand Down Expand Up @@ -562,12 +571,29 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
auto& states = op_info.states;
auto& builtin_data = op_info.builtin_data;

#ifdef NODE_TRACE_DB_MODE
vx::delegate::TfliteNodeIDPair tflite_node_id_pair;
std::vector<std::shared_ptr<tim::vx::Operation>> before_op_vector;
std::vector<std::shared_ptr<tim::vx::Operation>> after_op_vector;
#endif

std::vector<int> inputs_outputs;
std::copy(
inputs.begin(), inputs.end(), std::back_inserter(inputs_outputs));
std::copy(
outputs.begin(), outputs.end(), std::back_inserter(inputs_outputs));

#ifdef NODE_TRACE_DB_MODE
tflite_node_id_pair.builtin_code = builtin_code;
std::copy(
inputs.begin(), inputs.end(), std::back_inserter(tflite_node_id_pair.inputs));
std::copy(
outputs.begin(), outputs.end(), std::back_inserter(tflite_node_id_pair.outputs));
std::copy(
this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(before_op_vector));
tflite_node_id_map.push_back(tflite_node_id_pair);
#endif

for (size_t port_idx = 0; port_idx < inputs_outputs.size(); port_idx++) {
int tensor_idx = inputs_outputs[port_idx];
if (-1 != tensor_idx && tensors_.find(tensor_idx) == tensors_.end()) {
Expand Down Expand Up @@ -634,7 +660,15 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
states_tensors,
builtin_data.data());
}
#ifdef NODE_TRACE_DB_MODE
std::copy(
this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(after_op_vector));
vx::delegate::utils::MapTfliteNodeToTimVxNode(before_op_vector, after_op_vector, tflite_node_id_map);
#endif
}
#ifdef NODE_TRACE_DB_MODE
vx::delegate::utils::GenerateVxNodeTraceDb(tflite_node_id_map);
#endif

TFLITE_LOG(TFLITE_LOG_INFO, "Verifying graph");
// Do normalization op fusion before layout inference
Expand Down
11 changes: 11 additions & 0 deletions delegate_main.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ typedef struct {
bool error_during_invoke;
} VxDelegateOptions;

#ifdef NODE_TRACE_DB_MODE
typedef struct
{
//tflite node unique id
std::vector<int> inputs;
std::vector<int> outputs;
int builtin_code;
//tim wx node uid
std::vector<uint32_t> op_uids;
}TfliteNodeIDPair;
#endif

class Delegate;

Expand Down
69 changes: 68 additions & 1 deletion utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include "utils.h"
#include "tensorflow/lite/minimal_logging.h"

#ifdef NODE_TRACE_DB_MODE
#include "json/json.h"
#endif

using namespace tflite;

namespace vx {
Expand Down Expand Up @@ -66,7 +70,6 @@ std::vector<uint32_t> GetOvxTransposePerm(const std::vector<uint32_t>& perm) {
return ovx_perm;
}


void GenerateWeightsDataForBilinear(float* data,
const std::vector<uint32_t>& weight_shape,
uint32_t scale_w,
Expand Down Expand Up @@ -109,6 +112,70 @@ void GenerateWeightDataForNearest(float* data,
return;
}

#ifdef NODE_TRACE_DB_MODE
void MapTfliteNodeToTimVxNode(
const std::vector<std::shared_ptr<tim::vx::Operation>>& before_op_vector,
const std::vector<std::shared_ptr<tim::vx::Operation>>& after_op_vector,
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map) {
size_t new_operation_size = after_op_vector.size() - before_op_vector.size();
size_t i = 0;
std::vector<uint32_t> new_operation;
if (new_operation_size <= 0 || tflite_node_id_map.size() == 0) {
return;
}

for (i = 0; i < new_operation_size; i++) {
size_t new_operation_index = before_op_vector.size();
uint32_t uid = after_op_vector[new_operation_index + i]->uid();
tflite_node_id_map[tflite_node_id_map.size() - 1].op_uids.push_back(uid);
}
return;
}

void GenerateVxNodeTraceDb(
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map) {
Json::Value root;

Json::StyledWriter sw;
uint32_t i = 0;
std::fstream fs;
fs.open("vx_node_trace_db.json", std::ios::out | std::ios::trunc);

for (auto tflite_node_id_pair : tflite_node_id_map) {
Json::Value tflite_node_uid;
Json::Value tim_vx_uids;

Json::Value inputs_ids;
Json::Value outputs_ids;
Json::Value tflite_node_builtin_code;

Json::Value map_pair;
for (i = 0; i < tflite_node_id_pair.inputs.size(); i++) {
inputs_ids[i] = tflite_node_id_pair.inputs[i];
}
for (i = 0; i < tflite_node_id_pair.outputs.size(); i++) {
outputs_ids[i] = tflite_node_id_pair.outputs[i];
}
tflite_node_builtin_code = tflite_node_id_pair.builtin_code;
tflite_node_uid["inputs"] = inputs_ids;
tflite_node_uid["outputs"] = outputs_ids;
tflite_node_uid["builtin_code"] = tflite_node_id_pair.builtin_code;

for (i = 0; i < tflite_node_id_pair.op_uids.size(); i++) {
tim_vx_uids[i] = tflite_node_id_pair.op_uids[i];
}

map_pair["tflite_node_id"] = tflite_node_uid;
map_pair["tim_vx_uid"] = tim_vx_uids;
root.append(map_pair);
}

fs << sw.write(root);
fs.close();
return;
}
#endif

} // namespace utils
} // namespace delegate
} // namespace vx
8 changes: 8 additions & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ void GenerateWeightsDataForBilinear(float* data,
void GenerateWeightDataForNearest(float* data,
const std::vector<uint32_t>& weight_shape);

#ifdef NODE_TRACE_DB_MODE
void MapTfliteNodeToTimVxNode(const std::vector<std::shared_ptr<tim::vx::Operation>>& before_op_vector,
const std::vector<std::shared_ptr<tim::vx::Operation>>& after_op_vector,
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map);

void GenerateVxNodeTraceDb(std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map);
#endif

template <typename T>
inline void Quantize(const std::vector<float>& data, float scale,
int32_t zero_point, std::vector<T>& quant_data) {
Expand Down

0 comments on commit e30dc0a

Please sign in to comment.