Skip to content

Commit

Permalink
Pass error reporter to file copy allocation,
Browse files Browse the repository at this point in the history
and avoid loading model from file twice

PiperOrigin-RevId: 189256489
  • Loading branch information
tensorflower-gardener committed Mar 15, 2018
1 parent 239eb8b commit 6c62e65
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,19 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
return reinterpret_cast<jlong>(error_reporter);
}

// Checks that a raw buffer contains a valid TFLite flatbuffer model; on
// rejection the reason is surfaced through the supplied error reporter.
class JNIFlatBufferVerifier : public tflite::TfLiteVerifier {
 public:
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    const bool valid = VerifyModel(data, length);
    if (!valid) {
      reporter->Report("The model is not a valid Flatbuffer file");
    }
    return valid;
  }
};

JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
Expand All @@ -342,17 +355,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
if (error_reporter == nullptr) return 0;
const char* path = env->GetStringUTFChars(model_file, nullptr);

{
tflite::FileCopyAllocation allocation(path, nullptr);
if (!VerifyModel(allocation.base(), allocation.bytes())) {
throwException(env, kIllegalArgumentException,
"Contents of %s is not a valid flatbuffer model", path);
env->ReleaseStringUTFChars(model_file, path);
return 0;
}
}
std::unique_ptr<tflite::TfLiteVerifier> verifier;
verifier.reset(new JNIFlatBufferVerifier());

auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter);
auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
path, verifier.get(), error_reporter);
if (!model) {
throwException(env, kIllegalArgumentException,
"Contents of %s does not encode a valid TensorFlowLite "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ public final class NativeInterpreterWrapperTest {
private static final String MODEL_WITH_CUSTOM_OP_PATH =
"tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite";

private static final String NONEXISTING_MODEL_PATH =
"tensorflow/contrib/lite/java/src/testdata/nonexisting_model.bin";

@Test
public void testConstructor() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
Expand All @@ -60,7 +63,18 @@ public void testConstructorWithInvalidModel() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model");
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
}
}

// Verifies that constructing the wrapper from a path that does not exist
// throws IllegalArgumentException, and that the message carries both the
// verifier's rejection and the underlying file-open failure.
@Test
public void testConstructorWithNonexistingModel() {
  try {
    NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
    // Reaching this line means no exception was thrown, which is a failure.
    fail();
  } catch (IllegalArgumentException e) {
    assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
    assertThat(e).hasMessageThat().contains("Could not open");
  }
}

Expand Down
74 changes: 47 additions & 27 deletions tensorflow/contrib/lite/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,56 @@ namespace tflite {

const char* kEmptyTensorName = "";

namespace {

// Loads the model file at `filename` into an Allocation.
//
// If `mmap_file` is true the file is memory-mapped (using NNAPIAllocation
// when `use_nnapi` is set and NNAPI is available, otherwise MMAPAllocation);
// if false, the contents are copied into a FileCopyAllocation buffer.
// Open/map errors are reported through `error_reporter`; callers must still
// check Allocation::valid() on the result, since a non-null pointer is
// returned even on failure.
//
// Wrapped in an anonymous namespace: this is a file-local helper and should
// not leak external linkage out of this translation unit.
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
                                                  bool mmap_file,
                                                  ErrorReporter* error_reporter,
                                                  bool use_nnapi) {
  std::unique_ptr<Allocation> allocation;
  if (mmap_file) {
    if (use_nnapi && NNAPIExists()) {
      allocation.reset(new NNAPIAllocation(filename, error_reporter));
    } else {
      allocation.reset(new MMAPAllocation(filename, error_reporter));
    }
  } else {
    allocation.reset(new FileCopyAllocation(filename, error_reporter));
  }
  return allocation;
}

}  // namespace

// Builds a model by loading `filename` via the mmap path. Returns nullptr
// when the loaded contents fail to initialize a FlatBufferModel.
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
    const char* filename, ErrorReporter* error_reporter) {
  auto file_allocation = GetAllocationFromFile(
      filename, /*mmap_file=*/true, error_reporter, /*use_nnapi=*/true);
  // The FlatBufferModel takes ownership of the allocation.
  std::unique_ptr<FlatBufferModel> model(
      new FlatBufferModel(file_allocation.release(), error_reporter));
  if (!model->initialized()) {
    model.reset();
  }
  return model;
}

std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* verifier,
ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
if (verifier &&
!verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
return model;
}
model.reset(new FlatBufferModel(allocation.release(), error_reporter));
if (!model->initialized()) model.reset();
return model;
}

std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model;
model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
Allocation* allocation =
new MemoryAllocation(buffer, buffer_size, error_reporter);
model.reset(new FlatBufferModel(allocation, error_reporter));
if (!model->initialized()) model.reset();
return model;
}
Expand All @@ -57,23 +94,6 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
return model;
}

// Constructs a FlatBufferModel by loading `filename`. When `mmap_file` is
// true the file is memory-mapped (preferring NNAPIAllocation when
// `use_nnapi` is set and NNAPI is available, otherwise MMAPAllocation);
// when false the contents are copied into a FileCopyAllocation.
// If `error_reporter` is null, DefaultErrorReporter() is used instead.
// On failure (invalid allocation or an unrecognized flatbuffer identifier)
// the constructor returns early and model_ stays unset, which callers
// observe via initialized().
FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
                                 ErrorReporter* error_reporter, bool use_nnapi)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  if (mmap_file) {
    if (use_nnapi && NNAPIExists())
      allocation_ = new NNAPIAllocation(filename, error_reporter);
    else
      allocation_ = new MMAPAllocation(filename, error_reporter);
  } else {
    allocation_ = new FileCopyAllocation(filename, error_reporter);
  }
  // Bail out before parsing if the file could not be loaded or the buffer
  // does not carry the expected flatbuffer identifier.
  if (!allocation_->valid() || !CheckModelIdentifier()) return;

  model_ = ::tflite::GetModel(allocation_->base());
}

bool FlatBufferModel::CheckModelIdentifier() const {
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
Expand All @@ -85,21 +105,21 @@ bool FlatBufferModel::CheckModelIdentifier() const {
return true;
}

FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
FlatBufferModel::FlatBufferModel(const Model* model,
ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
if (!allocation_->valid()) return;

model_ = ::tflite::GetModel(allocation_->base());
model_ = model;
}

FlatBufferModel::FlatBufferModel(const Model* model,
FlatBufferModel::FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
model_ = model;
allocation_ = allocation;
if (!allocation_->valid() || !CheckModelIdentifier()) return;

model_ = ::tflite::GetModel(allocation_->base());
}

// Releases the backing storage: FlatBufferModel owns `allocation_`.
FlatBufferModel::~FlatBufferModel() { delete allocation_; }
Expand Down
37 changes: 20 additions & 17 deletions tensorflow/contrib/lite/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ limitations under the License.

namespace tflite {

// Abstract interface for validating a serialized model before it is built.
// Passing a verifier to VerifyAndBuildFromFile lets callers verify and build
// a model from a single load of the file instead of reading it twice.
class TfLiteVerifier {
 public:
  // Returns true if `data` (a buffer of `length` bytes) holds a valid model.
  // Implementations should report the reason for rejection via `reporter`.
  virtual bool Verify(const char* data, int length,
                      ErrorReporter* reporter) = 0;
  virtual ~TfLiteVerifier() {}
};

// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
class FlatBufferModel {
Expand All @@ -50,6 +61,12 @@ class FlatBufferModel {
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());

// Verifies whether the content of the file is legit, then builds a model
// based on the file. Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());

// Builds a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
// is destroyed. Returns a nullptr in case of failure.
Expand Down Expand Up @@ -82,23 +99,9 @@ class FlatBufferModel {
bool CheckModelIdentifier() const;

private:
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
//
// Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
// used.
explicit FlatBufferModel(
const char* filename, bool mmap_file = true,
ErrorReporter* error_reporter = DefaultErrorReporter(),
bool use_nnapi = false);

// Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
// to remain alive and unchanged until the end of this flatbuffermodel's
// lifetime.
//
// Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
// used.
FlatBufferModel(const char* ptr, size_t num_bytes,
// Loads a model from a given allocation. FlatBufferModel will take over the
// ownership of `allocation`, and delete it in its destructor.
FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());

// Loads a model from Model flatbuffer. The `model` has to remain alive and
Expand Down
32 changes: 32 additions & 0 deletions tensorflow/contrib/lite/model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,38 @@ TEST(BasicFlatBufferModel, TestNullModel) {
ASSERT_EQ(interpreter.get(), nullptr);
}

// Mocks the verifier by setting the result in ctor.
class FakeVerifier : public tflite::TfLiteVerifier {
public:
explicit FakeVerifier(bool result) : result_(result) {}
bool Verify(const char* data, int length,
tflite::ErrorReporter* reporter) override {
return result_;
}

private:
bool result_;
};

// A verifier that accepts the file must yield a usable model.
TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
  FakeVerifier accepting_verifier(true);
  auto model = FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin", &accepting_verifier);
  ASSERT_TRUE(model);
}

// A verifier that rejects the file must make the build return null.
TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
  FakeVerifier rejecting_verifier(false);
  auto model = FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin", &rejecting_verifier);
  ASSERT_FALSE(model);
}

// Passing a null verifier skips verification and still builds the model.
TEST(BasicFlatBufferModel, TestWithNullVerifier) {
  auto model = FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/contrib/lite/testdata/test_model.bin", nullptr);
  ASSERT_TRUE(model);
}

struct TestErrorReporter : public ErrorReporter {
int Report(const char* format, va_list args) override {
calls++;
Expand Down

0 comments on commit 6c62e65

Please sign in to comment.