Skip to content

Commit

Permalink
Unit tests for ONNX parser input tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
jezsadler committed Oct 20, 2023
1 parent d65b928 commit 9c3d356
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/omlt/io/onnx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def parse_network(self, graph, scaling_object, input_bounds):
size.append(dim.dim_value)
dim_value *= dim.dim_value
if dim_value is None:
raise ValueError(f"All dimensions in {graph} input tensor have 0 value.")
raise ValueError(f"All dimensions in graph \"{graph.name}\" input tensor have 0 value.")
assert network_input is None
network_input = InputLayer(size)
self._node_map[input.name] = network_input
network.add_layer(network_input)

if network_input is None:
raise ValueError(f"No valid input layer found in {graph}.")
raise ValueError(f"No valid input layer found in graph \"{graph.name}\".")

self._nodes = nodes
self._nodes_by_output = nodes_by_output
Expand Down Expand Up @@ -116,7 +116,7 @@ def parse_network(self, graph, scaling_object, input_bounds):
value = _parse_constant_value(node)
self._constants[output] = value
else:
raise ValueError(f"Nodes must have inputs or have op_type \"Constant\". {node} has no inputs and op_type {node.op_type}")
raise ValueError(f"Nodes must have inputs or have op_type \"Constant\". Node \"{node.name}\" has no inputs and op_type \"{node.op_type}\".")

# traverse graph
self._node_stack = list(inputs)
Expand Down
32 changes: 32 additions & 0 deletions tests/io/test_onnx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

if onnx_available:
from omlt.io.onnx import load_onnx_neural_network
from omlt.io.onnx_parser import NetworkParser


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
Expand Down Expand Up @@ -105,3 +106,34 @@ def test_maxpool(datadir):
assert layers[3].output_size == [3, 2, 1]
for layer in layers[1:]:
assert layer.kernel_depth == 3

@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_input_tensor_invalid_dims(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
model.graph.input[0].type.tensor_type.shape.dim[1].dim_value = 0
parser = NetworkParser()
with pytest.raises(ValueError) as excinfo:
parser.parse_network(model.graph,None,None)
expected_msg = "All dimensions in graph \"tf2onnx\" input tensor have 0 value."
assert str(excinfo.value) == expected_msg

@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_no_input_layers(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
model.graph.input.remove(model.graph.input[0])
parser = NetworkParser()
with pytest.raises(ValueError) as excinfo:
parser.parse_network(model.graph,None,None)
expected_msg = "No valid input layer found in graph \"tf2onnx\"."
assert str(excinfo.value) == expected_msg

@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_node_no_inputs(datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
while (len(model.graph.node[0].input) > 0):
model.graph.node[0].input.pop()
parser = NetworkParser()
with pytest.raises(ValueError) as excinfo:
parser.parse_network(model.graph,None,None)
expected_msg = "Nodes must have inputs or have op_type \"Constant\". Node \"StatefulPartitionedCall/keras_linear_131/dense/MatMul\" has no inputs and op_type \"MatMul\"."
assert str(excinfo.value) == expected_msg

0 comments on commit 9c3d356

Please sign in to comment.