Skip to content

Commit

Permalink
Support generate node trace db from vx_delegate to tim-vx
Browse files Browse the repository at this point in the history
Type: Code Improvement
  • Loading branch information
zhengzhouheng committed Dec 19, 2023
1 parent d8552f8 commit e3529cd
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 0 deletions.
28 changes: 28 additions & 0 deletions delegate_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
TfLiteContext* context,
TfLiteNode* node) {
TFLITE_LOG(TFLITE_LOG_INFO, "Delegate::Invoke node: %p", node->user_data);
char* node_trace_enable = getenv("VIV_VX_DUMP_NODE_TRACE_DB");
std::vector<vx::delegate::TfliteNodeIDPair> tflite_node_id_map;

if (!compiled_) {
// TODO(bo): Handling multi-thread use case
context_ = tim::vx::Context::Create();
Expand Down Expand Up @@ -562,12 +565,28 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
auto& states = op_info.states;
auto& builtin_data = op_info.builtin_data;

vx::delegate::TfliteNodeIDPair tflite_node_id_pair;
std::vector<std::shared_ptr<tim::vx::Operation>> before_op_vector;
std::vector<std::shared_ptr<tim::vx::Operation>> after_op_vector;

std::vector<int> inputs_outputs;
std::copy(
inputs.begin(), inputs.end(), std::back_inserter(inputs_outputs));
std::copy(
outputs.begin(), outputs.end(), std::back_inserter(inputs_outputs));

if(node_trace_enable && *node_trace_enable != 0)
{
tflite_node_id_pair.builtin_code = builtin_code;
std::copy(
inputs.begin(), inputs.end(), std::back_inserter(tflite_node_id_pair.inputs));
std::copy(
outputs.begin(), outputs.end(), std::back_inserter(tflite_node_id_pair.outputs));
std::copy(
this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(before_op_vector));
tflite_node_id_map.push_back(tflite_node_id_pair);
}

for (size_t port_idx = 0; port_idx < inputs_outputs.size(); port_idx++) {
int tensor_idx = inputs_outputs[port_idx];
if (-1 != tensor_idx && tensors_.find(tensor_idx) == tensors_.end()) {
Expand Down Expand Up @@ -634,6 +653,15 @@ TfLiteStatus Delegate::Invoke(const OpData& op_data,
states_tensors,
builtin_data.data());
}
if(node_trace_enable && *node_trace_enable != 0){
std::copy(
this->GetGraph()->OpVector().begin(), this->GetGraph()->OpVector().end(), std::back_inserter(after_op_vector));
vx::delegate::utils::MapTfliteNodeToTimVxNode(before_op_vector, after_op_vector, tflite_node_id_map);
}
}
if(node_trace_enable && *node_trace_enable != 0)
{
vx::delegate::utils::GenerateVxNodeTraceDb(tflite_node_id_map);
}

TFLITE_LOG(TFLITE_LOG_INFO, "Verifying graph");
Expand Down
9 changes: 9 additions & 0 deletions delegate_main.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ typedef struct {
bool error_during_invoke;
} VxDelegateOptions;

typedef struct
{
//tflite node unique id
std::vector<int> inputs;
std::vector<int> outputs;
int builtin_code;
//tim wx node uid
std::vector<uint32_t> op_uids;
}TfliteNodeIDPair;

class Delegate;

Expand Down
47 changes: 47 additions & 0 deletions utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,53 @@ void GenerateWeightDataForNearest(float* data,
return;
}

void MapTfliteNodeToTimVxNode(const std::vector<std::shared_ptr<tim::vx::Operation>>& before_op_vector,
const std::vector<std::shared_ptr<tim::vx::Operation>>& after_op_vector,
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map){
size_t new_operation_size = after_op_vector.size() - before_op_vector.size();
size_t i = 0;
std::vector<uint32_t> new_operation;
if(new_operation_size <= 0 || tflite_node_id_map.size () == 0)
{
return;
}

for(i = 0;i<new_operation_size;i++)
{
size_t new_operation_index = before_op_vector.size();
uint32_t uid = after_op_vector[new_operation_index + i]->uid();
tflite_node_id_map[tflite_node_id_map.size() -1].op_uids.push_back(uid);
}
return;
}

void GenerateVxNodeTraceDb(std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map)
{
std::fstream f;
f.open("vx_node_trace_db", std::ios::out | std::ios::trunc);
for(auto tflite_node_id_pair : tflite_node_id_map)
{
f<< tflite_node_id_pair.builtin_code
<<" "<<tflite_node_id_pair.inputs.size()
<<" "<<tflite_node_id_pair.outputs.size()
<<" "<<tflite_node_id_pair.op_uids.size();
for(auto input : tflite_node_id_pair.inputs)
{
f<<" "<<input;
}
for(auto output : tflite_node_id_pair.outputs)
{
f<<" "<<output;
}
for(auto op_uid : tflite_node_id_pair.op_uids)
{
f<<" "<<op_uid;
}
f<<std::endl;
}
f.close();
return;
}
} // namespace utils
} // namespace delegate
} // namespace vx
6 changes: 6 additions & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ void GenerateWeightsDataForBilinear(float* data,
void GenerateWeightDataForNearest(float* data,
const std::vector<uint32_t>& weight_shape);

void MapTfliteNodeToTimVxNode(const std::vector<std::shared_ptr<tim::vx::Operation>>& before_op_vector,
const std::vector<std::shared_ptr<tim::vx::Operation>>& after_op_vector,
std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map);

void GenerateVxNodeTraceDb(std::vector<vx::delegate::TfliteNodeIDPair>& tflite_node_id_map);

template <typename T>
inline void Quantize(const std::vector<float>& data, float scale,
int32_t zero_point, std::vector<T>& quant_data) {
Expand Down

0 comments on commit e3529cd

Please sign in to comment.