diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py
index 96a524d7d..2b076f26f 100644
--- a/src/concrete/ml/onnx/convert.py
+++ b/src/concrete/ml/onnx/convert.py
@@ -12,6 +12,7 @@
 import torch
 from onnx import checker, helper
 
+from ..common.debugging import assert_true
 from .onnx_utils import (
     IMPLEMENTED_ONNX_OPS,
     execute_onnx_with_numpy,
@@ -149,6 +150,14 @@ def get_equivalent_numpy_forward_from_torch(
         input_names=arguments,
     )
     equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))
+
+    # Check if the inputs are present in the model's graph
+    for input_name in arguments:
+        assert_true(
+            any(input_name == node.name for node in equivalent_onnx_model.graph.input),
+            f"Input '{input_name}' is not present in the ONNX model. Please check the onnx graph.",
+        )
+
     # Remove the tempfile if we used one
     if use_tempfile:
         output_onnx_file_path.unlink()
diff --git a/src/concrete/ml/onnx/ops_impl.py b/src/concrete/ml/onnx/ops_impl.py
index 1b13f8ef7..fa4191b75 100644
--- a/src/concrete/ml/onnx/ops_impl.py
+++ b/src/concrete/ml/onnx/ops_impl.py
@@ -1479,8 +1479,9 @@ def numpy_pad(
 def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     """Execute ONNX cast in Numpy.
 
-    For traced values during compilation, it supports only booleans, which are converted to float.
-    For raw values (used in constant folding or shape computations), any cast is allowed.
+    For traced values during compilation, it supports booleans and floats,
+    which are converted to float. For raw values (used in constant folding or shape computations),
+    any cast is allowed.
 
     See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
 
@@ -1495,7 +1496,8 @@ def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     if isinstance(data, RawOpOutput):
         return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),)
 
-    assert_true(to == onnx.TensorProto.BOOL)
+    # Allow both bool and float types
+    assert_true(to in (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT))
 
     # Will be used for traced values
     return (data.astype(numpy.float64),)
diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py
index 124bf151e..f3bcef1b1 100644
--- a/src/concrete/ml/quantization/quantized_module.py
+++ b/src/concrete/ml/quantization/quantized_module.py
@@ -99,28 +99,39 @@ def __init__(
         # Set base attributes for API consistency. This could be avoided if an abstract base class
         # is created for both Concrete ML models and QuantizedModule
         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899
-        self.fhe_circuit = None
+
+        all_or_none_params = [
+            ordered_module_input_names,
+            ordered_module_output_names,
+            quant_layers_dict,
+        ]
+        assert_true(
+            all(v is None or v == {} for v in all_or_none_params)
+            or not any(v is None or v == {} for v in all_or_none_params),
+            "All of ordered_module_input_names, ordered_module_output_names, "
+            "and quant_layers_dict must be provided if any one of them is provided.",
+        )
+        self.ordered_module_input_names = (
+            tuple(ordered_module_input_names) if ordered_module_input_names else ()
+        )
+        self.ordered_module_output_names = (
+            tuple(ordered_module_output_names) if ordered_module_output_names else ()
+        )
+        self.quant_layers_dict = (
+            copy.deepcopy(quant_layers_dict) if quant_layers_dict is not None else {}
+        )
+        self.input_quantizers: List[UniformQuantizer] = []
+        self.output_quantizers: List[UniformQuantizer] = []
+        self.fhe_circuit: Union[None, Circuit] = None
         self._is_compiled = False
-        self.input_quantizers = []
-        self.output_quantizers = []
         self._onnx_model = onnx_model
         self._post_processing_params: Dict[str, Any] = {}
 
-        # If any of the arguments are not provided, skip the init
-        if not all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict]):
-            return
-
-        # for mypy
-        assert isinstance(ordered_module_input_names, Iterable)
-        assert isinstance(ordered_module_output_names, Iterable)
-        assert all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict])
-        self.ordered_module_input_names = tuple(ordered_module_input_names)
-        self.ordered_module_output_names = tuple(ordered_module_output_names)
-
-        assert quant_layers_dict is not None
-        self.quant_layers_dict = copy.deepcopy(quant_layers_dict)
-
-        self.output_quantizers = self._set_output_quantizers()
+        # Initialize output quantizers based on quant_layers_dict
+        if self.quant_layers_dict:
+            self.output_quantizers = self._set_output_quantizers()
+        else:
+            self.output_quantizers = []
 
     # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
     def set_reduce_sum_copy(self):
diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py
index d5be6c0d1..cea49a70e 100644
--- a/tests/torch/test_compile_torch.py
+++ b/tests/torch/test_compile_torch.py
@@ -1373,3 +1373,29 @@ def test_mono_parameter_rounding_warning(
         verbose=False,
         get_and_compile=False,
     )
+
+
+def test_onnx_no_input():
+    """Test a torch model that has no input when converted to onnx."""
+
+    torch_input = torch.randn(100, 28)
+
+    class SimplifiedNet(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(28, 10)
+            self.threshold = torch.tensor(0.5, dtype=torch.float32)
+
+        def forward(self, x):
+            zeros = numpy.zeros_like(x)
+            x = x + zeros
+            x = (x > self.threshold).to(torch.float32)
+            x = self.fc1(x)
+            return x
+
+    model = SimplifiedNet()
+
+    with pytest.raises(AssertionError) as excinfo:
+        compile_torch_model(model, torch_input, rounding_threshold_bits=3)
+
+    assert "Input 'x' is not present in the ONNX model" in str(excinfo.value)
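For context on the convert.py change: operations in forward that leave the traced tensor world (such as numpy.zeros_like in the test above) can be baked into the exported graph as constants, so the model's real input may vanish from graph.input and previously failed much later with an opaque error. The new assert_true surfaces this at export time. Below is a minimal standalone sketch of the same logic, outside the patch; the helper name check_onnx_inputs, the toy model, and the file path are illustrative only, not part of the PR.

import onnx
import torch


def check_onnx_inputs(onnx_model, expected_inputs):
    """Raise AssertionError if an expected input is missing from the ONNX graph."""
    # graph.input lists the model's actual inputs as ValueInfoProto entries
    graph_input_names = {value_info.name for value_info in onnx_model.graph.input}
    for input_name in expected_inputs:
        if input_name not in graph_input_names:
            raise AssertionError(
                f"Input '{input_name}' is not present in the ONNX model. "
                "Please check the onnx graph."
            )


# Toy usage: export a single Linear layer and validate its input names.
model = torch.nn.Linear(28, 10)
dummy_input = torch.randn(1, 28)
torch.onnx.export(model, dummy_input, "toy.onnx", input_names=["x"])
check_onnx_inputs(onnx.load("toy.onnx"), ["x"])  # passes: 'x' is a real graph input

A model like SimplifiedNet in test_onnx_no_input would instead trip the AssertionError, since its exported graph has no 'x' among graph.input.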