Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support node trace db from tflite to tim vx #205

Merged
merged 1 commit into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cmake_minimum_required(VERSION 3.16)

option(TFLITE_ENABLE_MULTI_DEVICE "Enable multi devices support" OFF)
option(TFLITE_ENABLE_OPTIMIZE "Enable optimize tiny yolov4" OFF)
option(TFLITE_ENABLE_NODE_TRACE "Enable node trace" OFF)

if(TFLITE_ENABLE_OPTIMIZE)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DVSI_FEAT_OP_CUSTOM_TINY_YOLOV4_POSTPROCESS -DENABLE_TENSOR_CACHE")
Expand Down Expand Up @@ -64,6 +65,10 @@ if(TFLITE_ENABLE_MULTI_DEVICE)
ADD_DEFINITIONS(-DMULTI_DEVICE_FEATURE_MODE)
endif()

if(TFLITE_ENABLE_NODE_TRACE)
ADD_DEFINITIONS(-DNODE_TRACE_DB_MODE)
endif()

add_library(vx_delegate SHARED ${VX_DELEGATES_SRCS})

list(APPEND VX_CUSTOM_OP_SRCS
Expand Down
9 changes: 9 additions & 0 deletions cmake/modules/Findtim-vx.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ if(TFLITE_ENABLE_MULTI_DEVICE)
set(TIM_VX_ENABLE_40BIT "ON")
endif()

if(TFLITE_ENABLE_NODE_TRACE)
set(TIM_VX_ENABLE_NODE_TRACE "ON")
endif()
if((NOT DEFINED TIM_VX_INSTALL))
if(TFLITE_ENABLE_MULTI_DEVICE AND (NOT EXTERNAL_VIV_SDK))
message(FATAL_ERROR "FATAL: multi device only suppot 40 bit driver,
Expand All @@ -43,10 +46,16 @@ if((NOT DEFINED TIM_VX_INSTALL))
include_directories(${tim-vx_SOURCE_DIR}/include)
add_subdirectory("${tim-vx_SOURCE_DIR}"
"${tim-vx_BINARY_DIR}")
if(${TIM_VX_ENABLE_NODE_TRACE})
list(APPEND VX_DELEGATE_DEPENDENCIES ${tim-vx_BINARY_DIR}/_deps/jsoncpp-build/src/lib_json/libjsoncpp.so)
endif()
# list(APPEND VX_DELEGATE_DEPENDENCIES tim-vx)
else()
message("=== Building with TIM_VX_LIBRIRIES from ${TIM_VX_INSTALL} ===")
include_directories(${TIM_VX_INSTALL}/include)
set(LIBDIR lib)
list(APPEND VX_DELEGATE_DEPENDENCIES ${TIM_VX_INSTALL}/${LIBDIR}/libtim-vx.so)
if(${TIM_VX_ENABLE_NODE_TRACE})
list(APPEND VX_DELEGATE_DEPENDENCIES ${TIM_VX_INSTALL}/${LIBDIR}/libjsoncpp.so)
endif()
endif()
34 changes: 34 additions & 0 deletions delegate_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
#include "tim/transform/layout_inference.h"
#include "tim/transform/mean_stddev_normalize_fusion.h"

#ifdef NODE_TRACE_DB_MODE
#include "json/json.h"
#endif

using namespace tflite;
namespace {

Expand Down Expand Up @@ -529,6 +533,11 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
TfLiteContext* context,
TfLiteNode* node) {
TFLITE_LOG(TFLITE_LOG_INFO, "Delegate::Invoke node: %p", node->user_data);

#ifdef NODE_TRACE_DB_MODE
std::vector<vx::delegate::TfliteNodeIDPair> tflite_node_id_map;
#endif

if (!compiled_) {
// TODO(bo): Handling multi-thread use case
context_ = tim::vx::Context::Create();
Expand Down Expand Up @@ -562,12 +571,29 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
auto& states = op_info.states;
auto& builtin_data = op_info.builtin_data;

#ifdef NODE_TRACE_DB_MODE
vx::delegate::TfliteNodeIDPair tflite_node_id_pair;
std::vector<std::shared_ptr<tim::vx::Operation>> before_op_vector;
std::vector<std::shared_ptr<tim::vx::Operation>> after_op_vector;
#endif

std::vector<int> inputs_outputs;
std::copy(
inputs.begin(), inputs.end(), std::back_inserter(inputs_outputs));
std::copy(
outputs.begin(), outputs.end(), std::back_inserter(inputs_outputs));

#ifdef NODE_TRACE_DB_MODE
tflite_node_id_pair.builtin_code = builtin_code;
std::copy(
inputs.begin(), inputs.end(), std::back_inserter(tflite_node_id_pair.inputs));
std::copy(
outputs.begin(), outputs.end(), std::back_inserter(tflite_node_id_pair.outputs));
std::copy(
this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(before_op_vector));
tflite_node_id_map.push_back(tflite_node_id_pair);
#endif

for (size_t port_idx = 0; port_idx < inputs_outputs.size(); port_idx++) {
int tensor_idx = inputs_outputs[port_idx];
if (-1 != tensor_idx && tensors_.find(tensor_idx) == tensors_.end()) {
Expand Down Expand Up @@ -634,7 +660,15 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
states_tensors,
builtin_data.data());
}
#ifdef NODE_TRACE_DB_MODE
std::copy(
this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(after_op_vector));
vx::delegate::utils::MapTfliteNodeToTimVxNode(before_op_vector, after_op_vector, tflite_node_id_map);
#endif
}
#ifdef NODE_TRACE_DB_MODE
vx::delegate::utils::GenerateVxNodeTraceDb(tflite_node_id_map);
#endif

TFLITE_LOG(TFLITE_LOG_INFO, "Verifying graph");
// Do normalization op fusion before layout inference
Expand Down
11 changes: 11 additions & 0 deletions delegate_main.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ typedef struct {
bool error_during_invoke;
} VxDelegateOptions;

#ifdef NODE_TRACE_DB_MODE
typedef struct
{
//tflite node unique id
std::vector<int> inputs;
std::vector<int> outputs;
int builtin_code;
//tim wx node uid
std::vector<uint32_t> op_uids;
}TfliteNodeIDPair;
#endif

class Delegate;

Expand Down
69 changes: 68 additions & 1 deletion utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include "utils.h"
#include "tensorflow/lite/minimal_logging.h"

#ifdef NODE_TRACE_DB_MODE
#include "json/json.h"
#endif

using namespace tflite;

namespace vx {
Expand Down Expand Up @@ -66,7 +70,6 @@ std::vector<uint32_t> GetOvxTransposePerm(const std::vector<uint32_t>& perm) {
return ovx_perm;
}


void GenerateWeightsDataForBilinear(float* data,
const std::vector<uint32_t>& weight_shape,
uint32_t scale_w,
Expand Down Expand Up @@ -109,6 +112,70 @@ void GenerateWeightDataForNearest(float* data,
return;
}

#ifdef NODE_TRACE_DB_MODE
void MapTfliteNodeToTimVxNode(
const std::vector<std::shared_ptr<tim::vx::Operation>>& before_op_vector,
const std::vector<std::shared_ptr<tim::vx::Operation>>& after_op_vector,
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map) {
size_t new_operation_size = after_op_vector.size() - before_op_vector.size();
size_t i = 0;
std::vector<uint32_t> new_operation;
if (new_operation_size <= 0 || tflite_node_id_map.size() == 0) {
return;
}

for (i = 0; i < new_operation_size; i++) {
size_t new_operation_index = before_op_vector.size();
uint32_t uid = after_op_vector[new_operation_index + i]->uid();
tflite_node_id_map[tflite_node_id_map.size() - 1].op_uids.push_back(uid);
}
return;
}

void GenerateVxNodeTraceDb(
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map) {
Json::Value root;

Json::StyledWriter sw;
uint32_t i = 0;
std::fstream fs;
fs.open("vx_node_trace_db.json", std::ios::out | std::ios::trunc);

for (auto tflite_node_id_pair : tflite_node_id_map) {
Json::Value tflite_node_uid;
Json::Value tim_vx_uids;

Json::Value inputs_ids;
Json::Value outputs_ids;
Json::Value tflite_node_builtin_code;

Json::Value map_pair;
for (i = 0; i < tflite_node_id_pair.inputs.size(); i++) {
inputs_ids[i] = tflite_node_id_pair.inputs[i];
}
for (i = 0; i < tflite_node_id_pair.outputs.size(); i++) {
outputs_ids[i] = tflite_node_id_pair.outputs[i];
}
tflite_node_builtin_code = tflite_node_id_pair.builtin_code;
tflite_node_uid["inputs"] = inputs_ids;
tflite_node_uid["outputs"] = outputs_ids;
tflite_node_uid["builtin_code"] = tflite_node_id_pair.builtin_code;

for (i = 0; i < tflite_node_id_pair.op_uids.size(); i++) {
tim_vx_uids[i] = tflite_node_id_pair.op_uids[i];
}

map_pair["tflite_node_id"] = tflite_node_uid;
map_pair["tim_vx_uid"] = tim_vx_uids;
root.append(map_pair);
}

fs << sw.write(root);
fs.close();
return;
}
#endif

} // namespace utils
} // namespace delegate
} // namespace vx
8 changes: 8 additions & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ void GenerateWeightsDataForBilinear(float* data,
void GenerateWeightDataForNearest(float* data,
const std::vector<uint32_t>& weight_shape);

#ifdef NODE_TRACE_DB_MODE
void MapTfliteNodeToTimVxNode(const std::vector<std::shared_ptr<tim::vx::Operation>>& before_op_vector,
const std::vector<std::shared_ptr<tim::vx::Operation>>& after_op_vector,
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map);

void GenerateVxNodeTraceDb(std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map);
#endif

template <typename T>
inline void Quantize(const std::vector<float>& data, float scale,
int32_t zero_point, std::vector<T>& quant_data) {
Expand Down
Loading