tensorflow-lite: pipeline pre/post processing
koush committed Dec 28, 2024
1 parent 586f78e commit 6438ad1
Showing 3 changed files with 47 additions and 28 deletions.
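The commit restructures detect_once into three stages — prepare (input quantization), predict (the interpreter invocation), and post_process (dequantization and YOLO decoding) — and dispatches the pre/post stages to a dedicated thread pool so they can overlap with inference. A minimal sketch of the dispatch pattern, with hypothetical stage stubs standing in for the plugin's real functions:

import asyncio
import concurrent.futures

# Stand-ins for the plugin's prepare/predict/post_process stages.
def preprocess(frame): return frame
def invoke(tensor): return tensor
def postprocess(raw): return raw

prep_pool = concurrent.futures.ThreadPoolExecutor(thread_name_prefix="Prepare")
infer_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

async def detect(frame):
    loop = asyncio.get_event_loop()
    # Each stage awaits in its own pool, so while one frame occupies the
    # interpreter, another frame's pre/post processing can proceed.
    tensor = await loop.run_in_executor(prep_pool, lambda: preprocess(frame))
    raw = await loop.run_in_executor(infer_pool, lambda: invoke(tensor))
    return await loop.run_in_executor(prep_pool, lambda: postprocess(raw))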
4 changes: 2 additions & 2 deletions plugins/tensorflow-lite/package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion plugins/tensorflow-lite/package.json
@@ -58,5 +58,5 @@
   "devDependencies": {
     "@scrypted/sdk": "file:../../sdk"
   },
-  "version": "0.1.73"
+  "version": "0.1.74"
 }
69 changes: 44 additions & 25 deletions plugins/tensorflow-lite/src/tflite/__init__.py
@@ -30,6 +30,8 @@
 from common import yolo
 from predict import PredictPlugin
 
+prepareExecutor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix="TFLite-Prepare")
+
 availableModels = [
     "Default",
     "scrypted_yolov9s_relu_sep_320",
@@ -148,7 +150,8 @@ def downloadModel():
         try:
             interpreter = make_interpreter(modelFile, ":%s" % idx)
             interpreter.allocate_tensors()
-            _, height, width, channels = interpreter.get_input_details()[0][
+            self.image_input_details = interpreter.get_input_details()[0]
+            _, height, width, channels = self.image_input_details[
                 "shape"
             ]
             self.input_details = int(width), int(height), int(channels)
Expand All @@ -170,7 +173,8 @@ def downloadModel():
modelFile = downloadModel()
interpreter = tflite.Interpreter(model_path=modelFile)
interpreter.allocate_tensors()
_, height, width, channels = interpreter.get_input_details()[0]["shape"]
self.image_input_details = interpreter.get_input_details()[0]
_, height, width, channels = self.image_input_details["shape"]
self.input_details = int(width), int(height), int(channels)
available_interpreters.append(interpreter)
self.interpreter_count = self.interpreter_count + 1
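Both interpreter setup paths now cache interpreter.get_input_details()[0] as self.image_input_details, so the prepare stage can read the input tensor's dtype and quantization parameters without touching an interpreter from the prepare thread. For reference, a rough sketch of the fields the plugin reads from that dict (values here are illustrative, matching a 320x320 int8 model):

import numpy as np

image_input_details = {
    "shape": np.array([1, 320, 320, 3]),           # NHWC input tensor
    "dtype": np.int8,                              # quantized input
    "quantization": (0.003986024297773838, -128),  # (scale, zero_point)
}
_, height, width, channels = image_input_details["shape"]
scale, zero_point = image_input_details["quantization"]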
@@ -221,41 +225,54 @@ def get_input_size(self) -> Tuple[int, int]:
         return self.input_details[0:2]
 
     async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):
-        def predict():
+        def prepare():
+            if not self.yolo:
+                return input
+
+            im = np.stack([input])
+            # this non-quantized code path is unused but here for reference.
+            if self.image_input_details["dtype"] != np.int8 and self.image_input_details["dtype"] != np.int16:
+                im = im.astype(np.float32) / 255.0
+                return im
+
+            scale, zero_point = self.image_input_details["quantization"]
+            if scale == 0.003986024297773838 and zero_point == -128:
+                # fast path for quantization 1/255 = 0.003986024297773838
+                im = im.view(np.int8)
+                im -= 128
+            else:
+                im = im.astype(np.float32) / (255.0 * scale)
+                im = (im + zero_point).astype(np.int8)  # de-scale
+
+            return im
+
+        def predict(im):
             interpreter = self.interpreters[threading.current_thread().name]
             if not self.yolo:
-                tflite_common.set_input(interpreter, input)
+                tflite_common.set_input(interpreter, im)
                 interpreter.invoke()
                 objs = detect.get_objects(
                     interpreter, score_threshold=0.2, image_scale=(1, 1)
                 )
                 return objs
 
             tensor_index = input_details(interpreter, "index")
 
-            im = np.stack([input])
-            i = interpreter.get_input_details()[0]
-            if i["dtype"] == np.int8:
-                scale, zero_point = i["quantization"]
-                if scale == 0.003986024297773838 and zero_point == -128:
-                    # fast path for quantization 1/255 = 0.003986024297773838
-                    im = im.view(np.int8)
-                    im -= 128
-                else:
-                    im = im.astype(np.float32) / (255.0 * scale)
-                    im = (im + zero_point).astype(np.int8)  # de-scale
-            else:
-                # this code path is unused.
-                im = im.astype(np.float32) / 255.0
             interpreter.set_tensor(tensor_index, im)
             interpreter.invoke()
             output_details = interpreter.get_output_details()
-
-            # handle sseparate outputs for quantization accuracy
+            output_tensors = [(interpreter.get_tensor(output["index"]), output) for output in output_details]
+            return output_tensors
+
+        def post_process(output_tensors):
+            if not self.yolo:
+                return output_tensors
+
+            # handle separate outputs for quantization accuracy
             if self.scrypted_yolo_sep:
                 outputs = []
-                for output in output_details:
-                    o = interpreter.get_tensor(output["index"]).astype(np.float32)
+                for ot, output in output_tensors:
+                    o = ot.astype(np.float32)
                     scale, zero_point = output["quantization"]
                     o -= zero_point
                     o *= scale
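The fast path in prepare is worth a note: the model's stored input scale (0.003986…, zero point -128) differs slightly from an exact 1/255 ≈ 0.003922, but is treated as close enough that quantizing a uint8 pixel collapses to q = pixel - 128 — a zero-copy reinterpret of the uint8 buffer as int8 plus a wrapping subtraction, with no float round trip and a small, apparently acceptable error. A small sketch verifying the equivalence for an exact 1/255 scale:

import numpy as np

pixels = np.arange(256, dtype=np.uint8)  # every possible pixel value

# general path, as in the else branch: scale to [0, 1], divide by the
# quantization scale, then shift by the zero point
scale, zero_point = 1.0 / 255.0, -128
general = (pixels.astype(np.float32) / (255.0 * scale) + zero_point).astype(np.int8)

# fast path: reinterpret the uint8 buffer as int8; subtracting 128
# modulo 256 is the same as flipping the top bit
fast = pixels.view(np.int8) ^ np.int8(-128)

assert (general == fast).all()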
@@ -269,8 +286,7 @@ def predict():
                 return objs
 
             # this scale stuff can probably be optimized to dequantize ahead of time...
-            output = output_details[0]
-            x = interpreter.get_tensor(output["index"])
+            x, output = output_tensors[0]
             input_scale = self.get_input_details()[0]
 
             # this non-quantized code path is unused but here for reference.
@@ -300,7 +316,10 @@ def predict():
             )
             return objs
 
-        objs = await asyncio.get_event_loop().run_in_executor(self.executor, predict)
+
+        im = await asyncio.get_event_loop().run_in_executor(prepareExecutor, prepare)
+        output_tensors = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: predict(im))
+        objs = await asyncio.get_event_loop().run_in_executor(prepareExecutor, lambda: post_process(output_tensors))
 
         ret = self.create_detection_result(objs, src_size, cvss)
         return ret
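Note the stage split in the dispatch: prepare and post_process share the new prepareExecutor, while predict runs on self.executor, whose thread names key into self.interpreters so each inference thread keeps its own interpreter. Awaiting each stage separately lets the event loop interleave concurrent detect_once calls, which is presumably what pipelines the pre/post processing against inference.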
