Skip to content

Commit

Permalink
Upgrade TensorFlowLite
Browse files Browse the repository at this point in the history
  • Loading branch information
pschatzmann committed Nov 3, 2024
1 parent 1464fc0 commit fc3f009
Showing 1 changed file with 7 additions and 29 deletions.
36 changes: 7 additions & 29 deletions src/AudioTools/AudioLibs/TfLiteAudioStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Configure FFT to output 16 bit fixed point.
#define FIXED_POINT 16

//#include <MicroTFLite.h>
#include <TensorFlowLite.h>
#include <cmath>
#include <cstdint>
Expand All @@ -13,7 +14,6 @@
#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
Expand Down Expand Up @@ -55,25 +55,6 @@ class TfLiteWriter {
virtual bool begin(TfLiteAudioStreamBase *parent) = 0;
virtual bool write(const int16_t sample) = 0;
};
/**
* @brief Error Reporter using the Audio Tools Logger
* @ingroup tflite
* @author Phil Schatzmann
* @copyright GPLv3
*/
class TfLiteAudioErrorReporter : public tflite::ErrorReporter {
public:
virtual ~TfLiteAudioErrorReporter() {}
virtual int Report(const char* format, va_list args) override {
int result = snprintf(msg, 200, format, args);
LOGE(msg);
return result;
}

protected:
char msg[200];
} my_error_reporter;
tflite::ErrorReporter* error_reporter = &my_error_reporter;

/**
* @brief Configuration settings for TfLiteAudioStream
Expand All @@ -96,7 +77,7 @@ struct TfLiteConfig {
// Create an area of memory to use for input, output, and intermediate arrays.
// The size of this will depend on the model you’re using, and may need to be
// determined by experimentation.
int kTensorArenaSize = 10 * 1024;
size_t kTensorArenaSize = 10 * 1024;

// Keeping these as constant expressions allow us to allocate fixed-sized
// arrays on the stack for our working memory.
Expand Down Expand Up @@ -980,14 +961,12 @@ class TfLiteAudioStream : public TfLiteAudioStreamBase {
TRACEI();
if (cfg.useAllOpsResolver) {
tflite::AllOpsResolver resolver;
static tflite::MicroInterpreter static_interpreter(
p_model, resolver, p_tensor_arena, cfg.kTensorArenaSize,
error_reporter);
static tflite::MicroInterpreter static_interpreter{
p_model, resolver, p_tensor_arena, cfg.kTensorArenaSize};
p_interpreter = &static_interpreter;
} else {
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<4> micro_op_resolver(
error_reporter);
static tflite::MicroMutableOpResolver<4> micro_op_resolver{};
if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
return false;
}
Expand All @@ -1001,9 +980,8 @@ class TfLiteAudioStream : public TfLiteAudioStreamBase {
return false;
}
// Build an p_interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
p_model, micro_op_resolver, p_tensor_arena, cfg.kTensorArenaSize,
error_reporter);
static tflite::MicroInterpreter static_interpreter{
p_model, micro_op_resolver, p_tensor_arena, cfg.kTensorArenaSize};
p_interpreter = &static_interpreter;
}
}
Expand Down

0 comments on commit fc3f009

Please sign in to comment.