Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 2603 #2671

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add onnx model optimization tools to digit classifier
  • Loading branch information
MihirGore23 committed Jul 1, 2024
commit c7eb3b4114f56aa7155844b75c0718b598a6bcb8
51 changes: 41 additions & 10 deletions exercises/static/exercises/dl_digit_classifier/exercise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import cv2
import numpy as np
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType
from websocket_server import WebsocketServer

from gui import GUI, ThreadGUI
Expand Down Expand Up @@ -45,7 +46,8 @@ def __init__(self):
self.gui = GUI(self.host, self.hal)

# The process function
def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
# The process function
def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
"""
Given a DL model in onnx format, yield prediction per frame.
:param raw_dl_model: raw DL model transferred through websocket
Expand All @@ -60,16 +62,44 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
raw_dl_model_bytes = raw_dl_model.encode('ascii')
raw_dl_model_bytes = base64.b64decode(raw_dl_model_bytes)

# Load ONNX model
# Load and optimize ONNX model
ort_session = None
try:
with open(self.aux_model_fname, "wb") as f:
f.write(raw_dl_model_bytes)
ort_session = onnxruntime.InferenceSession(self.aux_model_fname)

# Load the original model
model = onnx.load(self.aux_model_fname)

# Apply graph optimization passes using the ONNX optimizer (onnx.optimizer)
model_optimized = onnx.optimizer.optimize(model, passes=[
"eliminate_identity",
"eliminate_deadend",
"eliminate_nop_dropout",
"eliminate_nop_transpose",
"fuse_bn_into_conv",
"fuse_consecutive_transposes",
"fuse_pad_into_conv",
"fuse_transpose_into_gemm",
"lift_lexical_references",
"nop_elimination",
"split_init"
])

# Save the optimized model
optimized_model_fname = "optimized_model.onnx"
onnx.save(model_optimized, optimized_model_fname)

# Quantize the model
quantized_model_fname = "quantized_model.onnx"
quantize_dynamic(optimized_model_fname, quantized_model_fname, weight_type=QuantType.QInt8)

# Load the quantized model
ort_session = onnxruntime.InferenceSession(quantized_model_fname)
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
print(str(exc_value))
print("ERROR: Model couldn't be loaded")
print("ERROR: Model couldn't be loaded or optimized")

try:
# Init auxiliary variables used for stabilized predictions
Expand Down Expand Up @@ -102,10 +132,10 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
pred = int(np.argmax(output, axis=1)) # get the index of the max log-probability

end = time.time()
frame_time = round(end-start, 3)
fps = 1.0/frame_time
frame_time = round(end - start, 3)
fps = 1.0 / frame_time
# number of consecutive frames that must be reached to consider a valid prediction
n_consecutive_frames = int(fps/2)
n_consecutive_frames = int(fps / 2)

# For stability, only show digit if detected in more than n consecutive frames
if pred != previous_established_pred:
Expand All @@ -122,9 +152,9 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
previous_pred = pred

# Show region used as ROI
cv2.rectangle(input_image,pt2=(w_border, h_border),pt1=(w_border + w_roi, h_border + h_roi),color=(255, 0, 0),thickness=3)
cv2.rectangle(input_image, pt2=(w_border, h_border), pt1=(w_border + w_roi, h_border + h_roi), color=(255, 0, 0), thickness=3)
# Show FPS count
cv2.putText(input_image, "FPS: {}".format(int(fps)), (7,25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
cv2.putText(input_image, "FPS: {}".format(int(fps)), (7, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

# Send result
self.gui.showResult(input_image, str(previous_established_pred))
Expand All @@ -136,7 +166,7 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):

self.iteration_counter += 1

# The code should be run for atleast the target time step
# The code should be run for at least the target time step
# If it's less put to sleep
if (ms < self.ideal_cycle):
time.sleep((self.ideal_cycle - ms) / 1000.0)
Expand All @@ -149,6 +179,7 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
exc_type, exc_value, exc_traceback = sys.exc_info()
print(str(exc_value))


# Function to measure the frequency of iterations
def measure_frequency(self):
previous_time = datetime.now()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import cv2
import numpy as np
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType
from websocket_server import WebsocketServer

from gui import GUI, ThreadGUI
Expand Down Expand Up @@ -45,7 +46,8 @@ def __init__(self):
self.gui = GUI(self.host, self.hal)

# The process function
def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
# The process function
def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
"""
Given a DL model in onnx format, yield prediction per frame.
:param raw_dl_model: raw DL model transferred through websocket
Expand All @@ -60,16 +62,44 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
raw_dl_model_bytes = raw_dl_model.encode('ascii')
raw_dl_model_bytes = base64.b64decode(raw_dl_model_bytes)

# Load ONNX model
# Load and optimize ONNX model
ort_session = None
try:
with open(self.aux_model_fname, "wb") as f:
f.write(raw_dl_model_bytes)
ort_session = onnxruntime.InferenceSession(self.aux_model_fname)

# Load the original model
model = onnx.load(self.aux_model_fname)

# Apply graph optimization passes using the ONNX optimizer (onnx.optimizer)
model_optimized = onnx.optimizer.optimize(model, passes=[
"eliminate_identity",
"eliminate_deadend",
"eliminate_nop_dropout",
"eliminate_nop_transpose",
"fuse_bn_into_conv",
"fuse_consecutive_transposes",
"fuse_pad_into_conv",
"fuse_transpose_into_gemm",
"lift_lexical_references",
"nop_elimination",
"split_init"
])

# Save the optimized model
optimized_model_fname = "optimized_model.onnx"
onnx.save(model_optimized, optimized_model_fname)

# Quantize the model
quantized_model_fname = "quantized_model.onnx"
quantize_dynamic(optimized_model_fname, quantized_model_fname, weight_type=QuantType.QInt8)

# Load the quantized model
ort_session = onnxruntime.InferenceSession(quantized_model_fname)
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
print(str(exc_value))
print("ERROR: Model couldn't be loaded")
print("ERROR: Model couldn't be loaded or optimized")

try:
# Init auxiliary variables used for stabilized predictions
Expand Down Expand Up @@ -102,10 +132,10 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
pred = int(np.argmax(output, axis=1)) # get the index of the max log-probability

end = time.time()
frame_time = round(end-start, 3)
fps = 1.0/frame_time
frame_time = round(end - start, 3)
fps = 1.0 / frame_time
# number of consecutive frames that must be reached to consider a valid prediction
n_consecutive_frames = int(fps/2)
n_consecutive_frames = int(fps / 2)

# For stability, only show digit if detected in more than n consecutive frames
if pred != previous_established_pred:
Expand All @@ -122,9 +152,9 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
previous_pred = pred

# Show region used as ROI
cv2.rectangle(input_image,pt2=(w_border, h_border),pt1=(w_border + w_roi, h_border + h_roi),color=(255, 0, 0),thickness=3)
cv2.rectangle(input_image, pt2=(w_border, h_border), pt1=(w_border + w_roi, h_border + h_roi), color=(255, 0, 0), thickness=3)
# Show FPS count
cv2.putText(input_image, "FPS: {}".format(int(fps)), (7,25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
cv2.putText(input_image, "FPS: {}".format(int(fps)), (7, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

# Send result
self.gui.showResult(input_image, str(previous_established_pred))
Expand All @@ -136,7 +166,7 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):

self.iteration_counter += 1

# The code should be run for atleast the target time step
# The code should be run for at least the target time step
# If it's less put to sleep
if (ms < self.ideal_cycle):
time.sleep((self.ideal_cycle - ms) / 1000.0)
Expand All @@ -149,6 +179,7 @@ def process_dl_model(self, raw_dl_model, roi_scale=0.75, input_size=(28, 28)):
exc_type, exc_value, exc_traceback = sys.exc_info()
print(str(exc_value))


# Function to measure the frequency of iterations
def measure_frequency(self):
previous_time = datetime.now()
Expand Down