chore: refactor quantized module init + fix no-input ONNX
jfrery committed Jan 23, 2024
1 parent a58f7d0 commit d5226d5
Showing 4 changed files with 69 additions and 21 deletions.
src/concrete/ml/onnx/convert.py: 9 additions, 0 deletions
@@ -12,6 +12,7 @@
 import torch
 from onnx import checker, helper
 
+from ..common.debugging import assert_true
 from .onnx_utils import (
     IMPLEMENTED_ONNX_OPS,
     execute_onnx_with_numpy,
@@ -149,6 +150,14 @@ def get_equivalent_numpy_forward_from_torch(
         input_names=arguments,
     )
     equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))
+
+    # Check if the inputs are present in the model's graph
+    for input_name in arguments:
+        assert_true(
+            any(input_name == node.name for node in equivalent_onnx_model.graph.input),
+            f"Input '{input_name}' is not present in the ONNX model. Please check the onnx graph.",
+        )
+
     # Remove the tempfile if we used one
     if use_tempfile:
         output_onnx_file_path.unlink()
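For reference, `graph.input` on an ONNX `ModelProto` holds one `ValueInfoProto` per declared graph input, and the new check simply tests membership by `name`. Below is a minimal, self-contained sketch of the same pattern; the tiny Relu model is illustrative and not part of this commit:

```python
import onnx
from onnx import TensorProto, helper

# Build an illustrative model with a single graph input named "x"
node = helper.make_node("Relu", inputs=["x"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "tiny_graph",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
)
model = helper.make_model(graph)

# Same membership test as the assert_true added above
for input_name in ["x"]:
    assert any(
        input_name == graph_input.name for graph_input in model.graph.input
    ), f"Input '{input_name}' is not present in the ONNX model."
```

Note that weights stored purely as initializers need not appear in `graph.input`, so the list reflects the runtime inputs the check cares about.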
src/concrete/ml/onnx/ops_impl.py: 5 additions, 3 deletions
@@ -1479,8 +1479,9 @@ def numpy_pad(
 def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     """Execute ONNX cast in Numpy.
 
-    For traced values during compilation, it supports only booleans, which are converted to float.
-    For raw values (used in constant folding or shape computations), any cast is allowed.
+    For traced values during compilation, it supports booleans and floats,
+    which are converted to float. For raw values (used in constant folding or shape computations),
+    any cast is allowed.
 
     See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
@@ -1495,7 +1496,8 @@ def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     if isinstance(data, RawOpOutput):
         return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),)
 
-    assert_true(to == onnx.TensorProto.BOOL)
+    # Allow both bool and float types
+    assert_true(to in (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT))
 
     # Will be used for traced values
     return (data.astype(numpy.float64),)
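To make the two branches concrete: raw values (constant folding and shape computations) may be cast to any ONNX dtype, while traced values now accept a BOOL or FLOAT target but are always returned as float64, since traced computation stays in floating point until quantization. A standalone sketch of the dispatch, with `RawOpOutput` stubbed as a bare ndarray subclass for illustration:

```python
import numpy
import onnx

class RawOpOutput(numpy.ndarray):
    """Stand-in for Concrete ML's marker type for raw (non-traced) values."""

def cast_sketch(data: numpy.ndarray, *, to: int):
    # Raw values: any ONNX dtype is a valid cast target
    if isinstance(data, RawOpOutput):
        return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),)

    # Traced values: only BOOL and FLOAT are accepted, both mapped to float64
    assert to in (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT)
    return (data.astype(numpy.float64),)

print(cast_sketch(numpy.array([0.0, 0.5, 1.0]), to=onnx.TensorProto.FLOAT)[0].dtype)  # float64
```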
src/concrete/ml/quantization/quantized_module.py: 29 additions, 18 deletions
@@ -99,28 +99,39 @@ def __init__(
         # Set base attributes for API consistency. This could be avoided if an abstract base class
         # is created for both Concrete ML models and QuantizedModule
         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899
+        self.fhe_circuit = None
+
+        all_or_none_params = [
+            ordered_module_input_names,
+            ordered_module_output_names,
+            quant_layers_dict,
+        ]
+        assert_true(
+            all(v is None or v == {} for v in all_or_none_params)
+            or not any(v is None or v == {} for v in all_or_none_params),
+            "All of ordered_module_input_names, ordered_module_output_names, "
+            "and quant_layers_dict must be provided if any one of them is provided.",
+        )
+        self.ordered_module_input_names = (
+            tuple(ordered_module_input_names) if ordered_module_input_names else ()
+        )
+        self.ordered_module_output_names = (
+            tuple(ordered_module_output_names) if ordered_module_output_names else ()
+        )
+        self.quant_layers_dict = (
+            copy.deepcopy(quant_layers_dict) if quant_layers_dict is not None else {}
+        )
+        self.input_quantizers: List[UniformQuantizer] = []
+        self.output_quantizers: List[UniformQuantizer] = []
-        self.fhe_circuit: Union[None, Circuit] = None
         self._is_compiled = False
-        self.input_quantizers = []
-        self.output_quantizers = []
         self._onnx_model = onnx_model
         self._post_processing_params: Dict[str, Any] = {}
 
-        # If any of the arguments are not provided, skip the init
-        if not all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict]):
-            return
-
-        # for mypy
-        assert isinstance(ordered_module_input_names, Iterable)
-        assert isinstance(ordered_module_output_names, Iterable)
-        assert all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict])
-        self.ordered_module_input_names = tuple(ordered_module_input_names)
-        self.ordered_module_output_names = tuple(ordered_module_output_names)
-
-        assert quant_layers_dict is not None
-        self.quant_layers_dict = copy.deepcopy(quant_layers_dict)
-
-        self.output_quantizers = self._set_output_quantizers()
+        # Initialize output quantizers based on quant_layers_dict
+        if self.quant_layers_dict:
+            self.output_quantizers = self._set_output_quantizers()
+        else:
+            self.output_quantizers = []
 
     # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
     def set_reduce_sum_copy(self):
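The net effect of the refactor: instead of returning early when arguments are missing and sprinkling `assert isinstance(...)` hints for mypy, `__init__` now validates the arguments all-or-none up front and then assigns None-safe defaults unconditionally. A minimal sketch of the same pattern outside the class, with a plain `assert` standing in for Concrete ML's `assert_true`; the function and parameter names here are illustrative:

```python
import copy
from typing import Dict, Iterable, Optional

def init_module_attributes(
    input_names: Optional[Iterable[str]],
    output_names: Optional[Iterable[str]],
    layers: Optional[Dict[str, object]],
):
    params = [input_names, output_names, layers]
    # Either every parameter is missing/empty, or every parameter is provided
    assert all(v is None or v == {} for v in params) or not any(
        v is None or v == {} for v in params
    ), "All of input_names, output_names, and layers must be provided together."

    # None-safe assignments: empty defaults when nothing was provided
    ordered_inputs = tuple(input_names) if input_names else ()
    ordered_outputs = tuple(output_names) if output_names else ()
    layers_dict = copy.deepcopy(layers) if layers is not None else {}
    return ordered_inputs, ordered_outputs, layers_dict

# Both of these pass; mixing provided and missing parameters raises AssertionError
init_module_attributes(None, None, None)
init_module_attributes(["x"], ["y"], {"fc1": object()})
```

Keeping every attribute assigned on every path means later code can rely on the attributes existing instead of guarding against a half-initialized module.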
tests/torch/test_compile_torch.py: 26 additions, 0 deletions
@@ -1373,3 +1373,29 @@ def test_mono_parameter_rounding_warning(
         verbose=False,
         get_and_compile=False,
     )
+
+
+def test_onnx_no_input():
+    """Test a torch model that has no input when converted to onnx."""
+
+    torch_input = torch.randn(100, 28)
+
+    class SimplifiedNet(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(28, 10)
+            self.threshold = torch.tensor(0.5, dtype=torch.float32)
+
+        def forward(self, x):
+            zeros = numpy.zeros_like(x)
+            x = x + zeros
+            x = (x > self.threshold).to(torch.float32)
+            x = self.fc1(x)
+            return x
+
+    model = SimplifiedNet()
+
+    with pytest.raises(AssertionError) as excinfo:
+        compile_torch_model(model, torch_input, rounding_threshold_bits=3)
+
+    assert "Input 'x' is not present in the ONNX model" in str(excinfo.value)
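The NumPy round-trip inside `forward` is the trigger: during tracing, `numpy.zeros_like(x)` converts the traced tensor to a concrete array, so the exporter can bake the result in as a constant and, as in this model, emit an ONNX graph with no input named 'x'. The new check in convert.py turns that silent failure into the assertion matched above. A hedged way to inspect which declared inputs survive a plain export (the helper name and file path are illustrative, not part of this commit):

```python
import onnx
import torch

def exported_input_names(model: torch.nn.Module, example: torch.Tensor) -> list:
    """Export a model to ONNX and return the graph input names that survived tracing."""
    torch.onnx.export(model, example, "exported.onnx", input_names=["x"])
    onnx_model = onnx.load("exported.onnx")
    return [graph_input.name for graph_input in onnx_model.graph.input]
```

For a well-behaved model this returns `['x']`; for `SimplifiedNet` above it can come back empty, which is exactly the condition the new `assert_true` reports.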
