diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 21bcff40bda553..cc448b03c3d85c 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -334,6 +334,19 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
   return reinterpret_cast<jlong>(error_reporter);
 }
 
+// Verifies whether the model is a flatbuffer file.
+class JNIFlatBufferVerifier : public tflite::TfLiteVerifier {
+ public:
+  bool Verify(const char* data, int length,
+              tflite::ErrorReporter* reporter) override {
+    if (!VerifyModel(data, length)) {
+      reporter->Report("The model is not a valid Flatbuffer file");
+      return false;
+    }
+    return true;
+  }
+};
+
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
     JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
@@ -342,17 +355,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
   if (error_reporter == nullptr) return 0;
   const char* path = env->GetStringUTFChars(model_file, nullptr);
 
-  {
-    tflite::FileCopyAllocation allocation(path, nullptr);
-    if (!VerifyModel(allocation.base(), allocation.bytes())) {
-      throwException(env, kIllegalArgumentException,
-                     "Contents of %s is not a valid flatbuffer model", path);
-      env->ReleaseStringUTFChars(model_file, path);
-      return 0;
-    }
-  }
+  std::unique_ptr<tflite::TfLiteVerifier> verifier;
+  verifier.reset(new JNIFlatBufferVerifier());
 
-  auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter);
+  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
+      path, verifier.get(), error_reporter);
   if (!model) {
     throwException(env, kIllegalArgumentException,
                    "Contents of %s does not encode a valid TensorFlowLite "
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index d6b4e9f438ab19..dbe45e5a05b822 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -47,6 +47,9 @@ public final class NativeInterpreterWrapperTest {
   private static final String MODEL_WITH_CUSTOM_OP_PATH =
       "tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite";
 
+  private static final String NONEXISTING_MODEL_PATH =
+      "tensorflow/contrib/lite/java/src/testdata/nonexisting_model.bin";
+
   @Test
   public void testConstructor() {
     NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
@@ -60,7 +63,18 @@ public void testConstructorWithInvalidModel() {
       NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
       fail();
     } catch (IllegalArgumentException e) {
-      assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model");
+      assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+    }
+  }
+
+  @Test
+  public void testConstructorWithNonexistingModel() {
+    try {
+      NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
+      fail();
+    } catch (IllegalArgumentException e) {
+      assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
+      assertThat(e).hasMessageThat().contains("Could not open");
     }
   }
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 3cf6bcbfcd6e01..f28d56af674982 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -32,11 +32,46 @@ namespace tflite {
 
 const char* kEmptyTensorName = "";
 
+// Loads a model from `filename`. If `mmap_file` is true then use mmap,
+// otherwise make a copy of the model in a buffer.
+std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
+                                                  bool mmap_file,
+                                                  ErrorReporter* error_reporter,
+                                                  bool use_nnapi) {
+  std::unique_ptr<Allocation> allocation;
+  if (mmap_file) {
+    if (use_nnapi && NNAPIExists())
+      allocation.reset(new NNAPIAllocation(filename, error_reporter));
+    else
+      allocation.reset(new MMAPAllocation(filename, error_reporter));
+  } else {
+    allocation.reset(new FileCopyAllocation(filename, error_reporter));
+  }
+  return allocation;
+}
+
 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
     const char* filename, ErrorReporter* error_reporter) {
   std::unique_ptr<FlatBufferModel> model;
-  model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
-                                  /*use_nnapi=*/true));
+  auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+                                          error_reporter, /*use_nnapi=*/true);
+  model.reset(new FlatBufferModel(allocation.release(), error_reporter));
+  if (!model->initialized()) model.reset();
+  return model;
+}
+
+std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
+    const char* filename, TfLiteVerifier* verifier,
+    ErrorReporter* error_reporter) {
+  std::unique_ptr<FlatBufferModel> model;
+  auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
+                                          error_reporter, /*use_nnapi=*/true);
+  if (verifier &&
+      !verifier->Verify(static_cast<const char*>(allocation->base()),
+                        allocation->bytes(), error_reporter)) {
+    return model;
+  }
+  model.reset(new FlatBufferModel(allocation.release(), error_reporter));
   if (!model->initialized()) model.reset();
   return model;
 }
@@ -44,7 +79,9 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
     const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
   std::unique_ptr<FlatBufferModel> model;
-  model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
+  Allocation* allocation =
+      new MemoryAllocation(buffer, buffer_size, error_reporter);
+  model.reset(new FlatBufferModel(allocation, error_reporter));
   if (!model->initialized()) model.reset();
   return model;
 }
@@ -57,23 +94,6 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
   return model;
 }
 
-FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
-                                 ErrorReporter* error_reporter, bool use_nnapi)
-    : error_reporter_(error_reporter ? error_reporter
-                                     : DefaultErrorReporter()) {
-  if (mmap_file) {
-    if (use_nnapi && NNAPIExists())
-      allocation_ = new NNAPIAllocation(filename, error_reporter);
-    else
-      allocation_ = new MMAPAllocation(filename, error_reporter);
-  } else {
-    allocation_ = new FileCopyAllocation(filename, error_reporter);
-  }
-  if (!allocation_->valid() || !CheckModelIdentifier()) return;
-
-  model_ = ::tflite::GetModel(allocation_->base());
-}
-
 bool FlatBufferModel::CheckModelIdentifier() const {
   if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
     const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
@@ -85,21 +105,21 @@ bool FlatBufferModel::CheckModelIdentifier() const {
   return true;
 }
 
-FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
+FlatBufferModel::FlatBufferModel(const Model* model,
                                  ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
-  allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
-  if (!allocation_->valid()) return;
-
-  model_ = ::tflite::GetModel(allocation_->base());
+  model_ = model;
 }
 
-FlatBufferModel::FlatBufferModel(const Model* model,
+FlatBufferModel::FlatBufferModel(Allocation* allocation,
                                  ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
-  model_ = model;
+  allocation_ = allocation;
+  if (!allocation_->valid() || !CheckModelIdentifier()) return;
+
+  model_ = ::tflite::GetModel(allocation_->base());
 }
 
 FlatBufferModel::~FlatBufferModel() { delete allocation_; }
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8dc1c794dce6f9..38eea0e26bc97c 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -41,6 +41,17 @@ limitations under the License.
 
 namespace tflite {
 
+// Abstract interface that verifies whether a given model is legit.
+// It facilitates the use-case to verify and build a model without loading it
+// twice.
+class TfLiteVerifier {
+ public:
+  // Returns true if the model is legit.
+  virtual bool Verify(const char* data, int length,
+                      ErrorReporter* reporter) = 0;
+  virtual ~TfLiteVerifier() {}
+};
+
 // An RAII object that represents a read-only tflite model, copied from disk,
 // or mmapped. This uses flatbuffers as the serialization format.
 class FlatBufferModel {
@@ -50,6 +61,12 @@ class FlatBufferModel {
       const char* filename,
       ErrorReporter* error_reporter = DefaultErrorReporter());
 
+  // Verifies whether the content of the file is legit, then builds a model
+  // based on the file. Returns a nullptr in case of failure.
+  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
+      const char* filename, TfLiteVerifier* verifier = nullptr,
+      ErrorReporter* error_reporter = DefaultErrorReporter());
+
   // Builds a model based on a pre-loaded flatbuffer. The caller retains
   // ownership of the buffer and should keep it alive until the returned object
   // is destroyed. Returns a nullptr in case of failure.
@@ -82,23 +99,9 @@ class FlatBufferModel {
   bool CheckModelIdentifier() const;
 
  private:
-  // Loads a model from `filename`. If `mmap_file` is true then use mmap,
-  // otherwise make a copy of the model in a buffer.
-  //
-  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
-  // used.
-  explicit FlatBufferModel(
-      const char* filename, bool mmap_file = true,
-      ErrorReporter* error_reporter = DefaultErrorReporter(),
-      bool use_nnapi = false);
-
-  // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
-  // to remain alive and unchanged until the end of this flatbuffermodel's
-  // lifetime.
-  //
-  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
-  // used.
-  FlatBufferModel(const char* ptr, size_t num_bytes,
+  // Loads a model from a given allocation. FlatBufferModel will take over the
+  // ownership of `allocation`, and delete it in destructor.
+  FlatBufferModel(Allocation* allocation,
                   ErrorReporter* error_reporter = DefaultErrorReporter());
 
   // Loads a model from Model flatbuffer. The `model` has to remain alive and
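Beyond the JNI verifier above, the `TfLiteVerifier` hook and `FlatBufferModel::VerifyAndBuildFromFile` entry point declared in model.h can be wired up by any C++ caller. The snippet below is a minimal sketch, not part of the patch: it assumes the flatc-generated `tflite::VerifyModelBuffer()` helper is available via schema_generated.h, and the names `FlatBufferChecker` and `LoadVerifiedModel` are illustrative only.

```cpp
// Sketch of a caller-side verifier for FlatBufferModel::VerifyAndBuildFromFile.
// Assumes the flatc-generated tflite::VerifyModelBuffer() is reachable through
// schema_generated.h; FlatBufferChecker and LoadVerifiedModel are made-up names.
#include <cstdint>
#include <memory>

#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"

// Rejects any buffer that does not pass full flatbuffer verification of the
// tflite::Model root table.
class FlatBufferChecker : public tflite::TfLiteVerifier {
 public:
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(data),
                                   static_cast<size_t>(length));
    if (!tflite::VerifyModelBuffer(verifier)) {
      reporter->Report("Model failed flatbuffer verification");
      return false;
    }
    return true;
  }
};

// Loads `path`, verifying it first; returns nullptr if verification or
// loading fails.
std::unique_ptr<tflite::FlatBufferModel> LoadVerifiedModel(const char* path) {
  FlatBufferChecker checker;
  return tflite::FlatBufferModel::VerifyAndBuildFromFile(
      path, &checker, tflite::DefaultErrorReporter());
}
```

Passing a null verifier (the default) skips the check entirely, so existing `BuildFromFile` callers keep the previous single-load behaviour.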
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index 66f22fd66a9ae0..ae6c1ece18963f 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -209,6 +209,38 @@ TEST(BasicFlatBufferModel, TestNullModel) {
   ASSERT_EQ(interpreter.get(), nullptr);
 }
 
+// Mocks the verifier by setting the result in ctor.
+class FakeVerifier : public tflite::TfLiteVerifier {
+ public:
+  explicit FakeVerifier(bool result) : result_(result) {}
+  bool Verify(const char* data, int length,
+              tflite::ErrorReporter* reporter) override {
+    return result_;
+  }
+
+ private:
+  bool result_;
+};
+
+TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
+  FakeVerifier verifier(true);
+  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+      "tensorflow/contrib/lite/testdata/test_model.bin",
+      &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
+  FakeVerifier verifier(false);
+  ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
+      "tensorflow/contrib/lite/testdata/test_model.bin",
+      &verifier));
+}
+
+TEST(BasicFlatBufferModel, TestWithNullVerifier) {
+  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
+      "tensorflow/contrib/lite/testdata/test_model.bin", nullptr));
+}
+
 struct TestErrorReporter : public ErrorReporter {
   int Report(const char* format, va_list args) override {
     calls++;
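The tests above stub out the verifier entirely; to exercise the failure path end to end, an always-rejecting verifier can be paired with an `ErrorReporter` that captures messages, in the spirit of `TestErrorReporter`. This is a minimal sketch under those assumptions, not part of the patch; `CapturingReporter`, `RejectAll`, and `DemoRejectedModel` are illustrative names.

```cpp
// Sketch of observing the verification failure path: an always-false verifier
// plus an ErrorReporter that records messages. CapturingReporter, RejectAll,
// and DemoRejectedModel are illustrative names, not part of the patch.
#include <cstdarg>
#include <cstdio>
#include <string>

#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"

class CapturingReporter : public tflite::ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    char buf[1024];
    int written = vsnprintf(buf, sizeof(buf), format, args);
    message_ += buf;
    return written;
  }
  const std::string& message() const { return message_; }

 private:
  std::string message_;
};

class RejectAll : public tflite::TfLiteVerifier {
 public:
  bool Verify(const char* /*data*/, int /*length*/,
              tflite::ErrorReporter* reporter) override {
    reporter->Report("rejected by RejectAll verifier");
    return false;
  }
};

// With the new VerifyAndBuildFromFile, a rejected model comes back as nullptr
// and the verifier's message reaches the supplied reporter.
bool DemoRejectedModel(const char* path) {
  CapturingReporter reporter;
  RejectAll verifier;
  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
      path, &verifier, &reporter);
  return model == nullptr && !reporter.message().empty();
}
```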