Commit: onnx validation
johndpope committed Oct 5, 2024
1 parent 54a159b commit 088b208
Showing 1 changed file with 47 additions and 7 deletions.
54 changes: 47 additions & 7 deletions onnxconv.py
@@ -7,6 +7,8 @@
from onnxconverter_common import float16
import onnx
import numpy as np
from onnx import shape_inference
import onnxruntime as ort

class IMFDecoder(nn.Module):
def __init__(self, model):
@@ -118,7 +120,6 @@ def convert_model_to_32bit(model, output_path):
onnx.save(model, output_path)
print(f"Converted model saved to {output_path}")


def export_to_onnx(model, x_current, x_reference, file_name):
try:
print("Model structure before tracing:")
@@ -167,18 +168,57 @@ def export_to_onnx(model, x_current, x_reference, file_name):
# Load the ONNX model
onnx_model = onnx.load(file_name)

# Check the model
print("\nChecking the model...")
onnx.checker.check_model(onnx_model)
print("Model checked successfully")

# Print model input and output shapes
print("\nModel Input and Output Shapes:")
for inp in onnx_model.graph.input:
print(f"Input: {inp.name}, Shape: {[dim.dim_value for dim in inp.type.tensor_type.shape.dim]}")
for out in onnx_model.graph.output:
print(f"Output: {out.name}, Shape: {[dim.dim_value for dim in out.type.tensor_type.shape.dim]}")

# Perform shape inference
print("\nPerforming shape inference...")
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.save(inferred_model, file_name)
print("Shape inference completed and model saved")

# Convert int64 to int32
print("\nConverting int64 to int32...")
web_compatible_file = file_name.replace('.onnx', '_web.onnx')
convert_model_to_32bit(onnx_model, web_compatible_file)
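# convert_model_to_32bit saves the converted graph to web_compatible_file
# itself (it calls onnx.save internally), so no return value is needed here.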

# Validate the converted model
print("\nValidating the converted model...")
onnx.checker.check_model(onnx.load(web_compatible_file))
print("Converted model validated successfully")

# Test the model with ONNX Runtime
print("\nTesting the model with ONNX Runtime...")
ort_session = ort.InferenceSession(web_compatible_file)

# Prepare inputs (assuming x_current and x_reference are PyTorch tensors)
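# The dict keys must match the input_names used at export time, and .numpy()
# requires CPU tensors that do not require grad.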
ort_inputs = {
'x_current': x_current.numpy(),
'x_reference': x_reference.numpy()
}

# Run inference
ort_outputs = ort_session.run(None, ort_inputs)
print("ONNX Runtime inference successful")

print(f"\nConverted and validated model saved to {web_compatible_file}")
print("This model should now be compatible with WONNX")

except Exception as e:
print(f"Error during ONNX export: {str(e)}")
print(f"Error during ONNX export and validation: {str(e)}")
import traceback
traceback.print_exc()


# Load your model
model = IMFModel()
model.eval()
