
Commit a278737

Update tensorflow_models usage of tf.lite.interpreter to run ai-edge-litert.interpreter

PiperOrigin-RevId: 682467705
ecalubaquib authored and copybara-github committed Oct 18, 2024
1 parent 8f60a93 commit a278737
Showing 2 changed files with 7 additions and 2 deletions.
requirements.txt (1 addition, 1 deletion)

@@ -12,4 +12,4 @@ OpenEXR >= 1.3.2
 termcolor >= 1.1.0
 trimesh >= 2.37.22
 # Required by trimesh.
-networkx
+networkx
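
Note that the new dependency's distribution name and its Python import name differ; the short sketch below assumes the package is installed from PyPI (the exact requirements.txt entry is not visible in this diff):

# Distribution name (pip / requirements.txt) uses hyphens:
#   pip install ai-edge-litert
# The importable module uses underscores, matching the test_case.py change below.
from ai_edge_litert import interpreter as tfl_interpreter

# The class is intended as a drop-in replacement for tf.lite.Interpreter.
print(tfl_interpreter.Interpreter)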
tensorflow_graphics/util/test_case.py (6 additions, 1 deletion)

@@ -31,6 +31,10 @@
 import tensorflow as tf

 from tensorflow_graphics.util import tfg_flags
+# pylint: disable=g-direct-tensorflow-import
+from ai_edge_litert import interpreter as tfl_interpreter
+# pylint: enable=g-direct-tensorflow-import


 FLAGS = flags.FLAGS

@@ -98,6 +102,7 @@ def _compute_gradient_error(self, x, y, x_init, delta=1e-6):
     error = 0
     row_max_error = 0
     column_max_error = 0
+    max_error = 0
     for j_t, j_n in grad:
       if j_t.size or j_n.size:  # Handle zero size tensors correctly
         diff = np.fabs(j_t - j_n)
@@ -364,7 +369,7 @@ def assert_tf_lite_convertible(self,
           sess, in_tensors, out_tensors)
       tflite_model = converter.convert()
       # Load TFLite model and allocate tensors.
-      interpreter = tf.lite.Interpreter(model_content=tflite_model)
+      interpreter = tfl_interpreter.Interpreter(model_content=tflite_model)
       interpreter.allocate_tensors()
       # If no test inputs provided then randomly generate inputs.
       if test_inputs is None:
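
For context on what the updated call site exercises, here is a minimal, self-contained sketch of running a converted model through the standalone LiteRT interpreter. This is not the repository's test harness: the helper name, input handling, and dtype are illustrative assumptions; only the Interpreter methods used (allocate_tensors, get_input_details, set_tensor, invoke, get_output_details, get_tensor) come from the library's public API, which mirrors tf.lite.Interpreter.

import numpy as np
from ai_edge_litert import interpreter as tfl_interpreter


def run_tflite_once(tflite_model: bytes, test_input: np.ndarray) -> np.ndarray:
  """Runs a single inference on a serialized TFLite model (illustrative)."""
  # Load the flatbuffer and allocate tensors, as in the updated test_case.py.
  interp = tfl_interpreter.Interpreter(model_content=tflite_model)
  interp.allocate_tensors()

  # Feed the first input tensor; real models may declare several inputs.
  input_details = interp.get_input_details()
  interp.set_tensor(input_details[0]["index"], test_input.astype(np.float32))

  # Run inference and read back the first output tensor.
  interp.invoke()
  output_details = interp.get_output_details()
  return interp.get_tensor(output_details[0]["index"])

Because the constructor and method signatures match tf.lite.Interpreter, the diff above only needs to swap the module that provides Interpreter.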

0 comments on commit a278737
