chore: refacto quantized module init + handle no input onnx
jfrery authored Jan 26, 2024
1 parent 115ddbd commit e6d3f35
Showing 6 changed files with 133 additions and 22 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -143,6 +143,8 @@ filterwarnings = [
     "ignore:non-integer arguments to randrange\\(\\) have been deprecated since Python 3\\.10 and will be removed in a subsequent version:DeprecationWarning",
     "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning",
     "ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning",
+    "ignore:Converting a tensor to a NumPy array might cause the trace to be incorrect.",
+    "ignore:torch.from_numpy results are registered as constants in the trace.",
 ]

 [tool.semantic_release]
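
The two new filters silence PyTorch TracerWarnings that the no-input ONNX test added below in tests/torch/test_compile_torch.py triggers on purpose. A minimal sketch (not part of the commit) of how these warnings arise, assuming only torch and numpy:

import warnings

import numpy
import torch


class ConstantFoldNet(torch.nn.Module):
    """Forward pass that round-trips through NumPy during tracing."""

    def forward(self, x):
        # Tensor -> NumPy triggers "Converting a tensor to a NumPy array
        # might cause the trace to be incorrect."
        zeros = numpy.zeros_like(x)
        # NumPy -> tensor triggers "torch.from_numpy results are registered
        # as constants in the trace."
        return x + torch.from_numpy(zeros)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.jit.trace(ConstantFoldNet(), torch.randn(2, 3))

print([str(warning.message) for warning in caught])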
10 changes: 10 additions & 0 deletions src/concrete/ml/onnx/convert.py
@@ -12,6 +12,7 @@
 import torch
 from onnx import checker, helper

+from ..common.debugging import assert_true
 from .onnx_utils import (
     IMPLEMENTED_ONNX_OPS,
     execute_onnx_with_numpy,
@@ -149,6 +150,15 @@ def get_equivalent_numpy_forward_from_torch(
         input_names=arguments,
     )
     equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))
+
+    # Check if the inputs are present in the model's graph
+    for input_name in arguments:
+        assert_true(
+            any(input_name == node.name for node in equivalent_onnx_model.graph.input),
+            f"Input '{input_name}' is missing in the ONNX graph after export. "
+            "Verify the forward pass for issues.",
+        )
+
     # Remove the tempfile if we used one
     if use_tempfile:
         output_onnx_file_path.unlink()
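
The new check exists because torch.onnx.export drops inputs that the traced graph never uses: when a forward pass escapes to NumPy, the traced output becomes a constant and the exported graph ends up with no inputs at all. A standalone sketch (not part of the commit) of the failure mode this guard now reports, assuming only torch and onnx:

import numpy
import onnx
import torch


class NoInputNet(torch.nn.Module):
    """The NumPy round-trip makes the output a constant of the trace."""

    def forward(self, x):
        return torch.from_numpy(numpy.zeros_like(x)) + 1.0


torch.onnx.export(NoInputNet(), torch.randn(2, 3), "no_input.onnx", input_names=["x"])
onnx_model = onnx.load_model("no_input.onnx")

# The exported graph has no inputs, so the guard above would report 'x' as missing
print([graph_input.name for graph_input in onnx_model.graph.input])  # []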
13 changes: 9 additions & 4 deletions src/concrete/ml/onnx/ops_impl.py
@@ -1479,9 +1479,9 @@ def numpy_pad(
 def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     """Execute ONNX cast in Numpy.

-    For traced values during compilation, it supports only booleans, which are converted to float.
-    For raw values (used in constant folding or shape computations), any cast is allowed.
+    This function supports casting to booleans, floats, and double for traced values,
+    converting them accordingly. For raw values (used in constant folding or shape
+    computations), any cast is allowed.

     See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast

     Args:
@@ -1495,7 +1495,12 @@ def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
     if isinstance(data, RawOpOutput):
         return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),)

-    assert_true(to == onnx.TensorProto.BOOL)
+    allowed_types = (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT, onnx.TensorProto.DOUBLE)
+    assert to in allowed_types, (
+        f"Invalid 'to' data type: {onnx.TensorProto.DataType.Name(to)}. "
+        f"Only {', '.join(onnx.TensorProto.DataType.Name(t) for t in allowed_types)} "
+        "are allowed for casting."
+    )

     # Will be used for traced values
     return (data.astype(numpy.float64),)
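
With the relaxed assertion, a comparison followed by a cast to float32 or float64 (as in the new CastNet test below) now passes tracing, and every accepted traced cast still produces float64. A hedged sketch of the behavior, assuming numpy_cast can be imported and called directly from concrete.ml.onnx.ops_impl:

import numpy
import onnx

from concrete.ml.onnx.ops_impl import numpy_cast

data = numpy.array([0.0, 1.0, 2.0])

# Traced values: BOOL, FLOAT and DOUBLE are all accepted and land on float64
for to in (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT, onnx.TensorProto.DOUBLE):
    (result,) = numpy_cast(data, to=to)
    assert result.dtype == numpy.float64

# Any other target type is rejected
try:
    numpy_cast(data, to=onnx.TensorProto.INT64)
except AssertionError as error:
    print(error)  # Invalid 'to' data type: INT64. ...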
51 changes: 33 additions & 18 deletions src/concrete/ml/quantization/quantized_module.py
@@ -96,31 +96,46 @@ def __init__(
         quant_layers_dict: Optional[Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]] = None,
         onnx_model: Optional[onnx.ModelProto] = None,
     ):

+        all_or_none_params = [
+            ordered_module_input_names,
+            ordered_module_output_names,
+            quant_layers_dict,
+        ]
+        if not (
+            all(v is None or v == {} for v in all_or_none_params)
+            or not any(v is None or v == {} for v in all_or_none_params)
+        ):
+            raise ValueError(
+                "Please either set all three 'ordered_module_input_names', "
+                "'ordered_module_output_names' and 'quant_layers_dict' or none of them."
+            )
+
+        self.ordered_module_input_names = (
+            tuple(ordered_module_input_names) if ordered_module_input_names else ()
+        )
+        self.ordered_module_output_names = (
+            tuple(ordered_module_output_names) if ordered_module_output_names else ()
+        )
+        self.quant_layers_dict = (
+            copy.deepcopy(quant_layers_dict) if quant_layers_dict is not None else {}
+        )
+
         # Set base attributes for API consistency. This could be avoided if an abstract base class
         # is created for both Concrete ML models and QuantizedModule
         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899
-        self.fhe_circuit = None
-        self.input_quantizers: List[UniformQuantizer] = []
-        self.output_quantizers: List[UniformQuantizer] = []
+        self.fhe_circuit: Optional[Circuit] = None
+        self._is_compiled = False
+        self.input_quantizers = []
+        self.output_quantizers = []
         self._onnx_model = onnx_model
         self._post_processing_params: Dict[str, Any] = {}

-        # If any of the arguments are not provided, skip the init
-        if not all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict]):
-            return
-
-        # for mypy
-        assert isinstance(ordered_module_input_names, Iterable)
-        assert isinstance(ordered_module_output_names, Iterable)
-        assert all([ordered_module_input_names, ordered_module_output_names, quant_layers_dict])
-        self.ordered_module_input_names = tuple(ordered_module_input_names)
-        self.ordered_module_output_names = tuple(ordered_module_output_names)
-
-        assert quant_layers_dict is not None
-        self.quant_layers_dict = copy.deepcopy(quant_layers_dict)
-
-        self.output_quantizers = self._set_output_quantizers()
+        # Initialize output quantizers based on quant_layers_dict
+        if self.quant_layers_dict:
+            self.output_quantizers = self._set_output_quantizers()
+        else:
+            self.output_quantizers = []

     # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
     def set_reduce_sum_copy(self):
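
The refactor replaces the old early-return path, which could leave a half-initialized object, with an explicit all-or-none contract: a bare QuantizedModule() is a valid, fully-attributed empty module, while a partial set of arguments fails fast. A sketch of the contract (not part of the commit), assuming QuantizedModule is importable from concrete.ml.quantization:

from concrete.ml.quantization import QuantizedModule

# Omitting all three construction arguments yields a valid empty module
empty_module = QuantizedModule()
assert empty_module.quant_layers_dict == {}
assert empty_module.output_quantizers == []

# Providing only some of the arguments now raises instead of silently returning
try:
    QuantizedModule(ordered_module_input_names=["input_0"])
except ValueError as error:
    print(error)  # Please either set all three ... or none of them.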
15 changes: 15 additions & 0 deletions tests/quantization/test_quantized_module.py
@@ -496,3 +496,18 @@ def test_serialization(model_class, input_shape):
         QuantizedModule,
         equal_method=partial(quantized_module_predictions_are_equal, x=numpy_input),
     )
+
+
+def test_quantized_module_initialization_error():
+    """Test initialization fails with mismatched parameters."""
+    # Initialize with invalid parameters
+    with pytest.raises(
+        ValueError,
+        match=r"Please either set all three 'ordered_module_input_names', "
+        r"'ordered_module_output_names' and 'quant_layers_dict' or none of them.",
+    ):
+        QuantizedModule(
+            ordered_module_input_names=["input1", "input2"],
+            ordered_module_output_names=None,  # This makes the combination invalid
+            quant_layers_dict={"layer1": (["input1"], "QuantizedOp")},
+        )
64 changes: 64 additions & 0 deletions tests/torch/test_compile_torch.py
@@ -1373,3 +1373,67 @@ def test_mono_parameter_rounding_warning(
         verbose=False,
         get_and_compile=False,
     )
+
+
+@pytest.mark.parametrize(
+    "cast_type, should_fail, error_message",
+    [
+        (torch.bool, False, None),
+        (torch.float32, False, None),
+        (torch.float64, False, None),
+        (torch.int64, True, r"Invalid 'to' data type: INT64"),
+    ],
+)
+def test_compile_torch_model_with_cast(cast_type, should_fail, error_message):
+    """Test compiling a Torch model with various casts, expecting failure for invalid types."""
+    torch_input = torch.randn(100, 28)
+
+    class CastNet(nn.Module):
+        """Network with cast."""
+
+        def __init__(self, cast_to):
+            super().__init__()
+            self.threshold = torch.tensor(0.5, dtype=torch.float32)
+            self.cast_to = cast_to
+
+        def forward(self, x):
+            """Forward pass with dynamic cast."""
+            zeros = torch.zeros_like(x)
+            x = x + zeros
+            x = (x > self.threshold).to(self.cast_to)
+            return x
+
+    model = CastNet(cast_type)
+
+    if should_fail:
+        with pytest.raises(AssertionError, match=error_message):
+            compile_torch_model(model, torch_input, rounding_threshold_bits=3)
+    else:
+        compile_torch_model(model, torch_input, rounding_threshold_bits=3)
+
+
+def test_onnx_no_input():
+    """Test a torch model that has no input when converted to onnx."""

+    torch_input = torch.randn(100, 28)
+
+    class NoInputNet(nn.Module):
+        """Network with no input in the onnx graph."""
+
+        def __init__(self):
+            super().__init__()
+            self.threshold = torch.tensor(0.5, dtype=torch.float32)
+
+        def forward(self, x):
+            """Forward pass."""
+            zeros = numpy.zeros_like(x)
+            x = x + zeros
+            x = (x > self.threshold).to(torch.float32)
+            return x
+
+    model = NoInputNet()
+
+    with pytest.raises(
+        AssertionError, match="Input 'x' is missing in the ONNX graph after export."
+    ):
+        compile_torch_model(model, torch_input, rounding_threshold_bits=3)
