chore: review have a ValueError instead of assert_true in qmodule init
jfrery committed Jan 24, 2024
1 parent 97ae7a5 commit 0740717
Showing 5 changed files with 80 additions and 22 deletions.
3 changes: 2 additions & 1 deletion src/concrete/ml/onnx/convert.py
@@ -155,7 +155,8 @@ def get_equivalent_numpy_forward_from_torch(
for input_name in arguments:
assert_true(
any(input_name == node.name for node in equivalent_onnx_model.graph.input),
f"Input '{input_name}' is not present in the ONNX model. Please check the onnx graph.",
f"Input '{input_name}' is missing in the ONNX graph after export. "
"Verify the forward pass for issues.",
)

# Remove the tempfile if we used one
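For context, the check above requires every traced forward argument to appear among the inputs of the exported ONNX graph. A minimal standalone sketch of the same membership test (not the Concrete ML implementation; check_onnx_inputs is a hypothetical helper):

import onnx

def check_onnx_inputs(onnx_model: onnx.ModelProto, arguments) -> None:
    # Collect the input names actually present in the exported ONNX graph
    graph_input_names = {graph_input.name for graph_input in onnx_model.graph.input}
    for input_name in arguments:
        # Reject any traced forward argument that did not survive the export
        if input_name not in graph_input_names:
            raise ValueError(
                f"Input '{input_name}' is missing in the ONNX graph after export. "
                "Verify the forward pass for issues."
            )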
13 changes: 8 additions & 5 deletions src/concrete/ml/onnx/ops_impl.py
@@ -1479,10 +1479,9 @@ def numpy_pad(
def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
"""Execute ONNX cast in Numpy.
- For traced values during compilation, it supports booleans and floats,
- which are converted to float. For raw values (used in constant folding or shape computations),
+ This function supports casting to booleans, floats, and double for traced values,
+ converting them accordingly. For raw values (used in constant folding or shape computations),
any cast is allowed.
See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
Args:
@@ -1496,8 +1495,12 @@ def numpy_cast(data: numpy.ndarray, *, to: int) -> Tuple[numpy.ndarray]:
if isinstance(data, RawOpOutput):
return (data.astype(onnx.helper.tensor_dtype_to_np_dtype(to)).view(RawOpOutput),)

- # Allow both bool and float types
- assert_true(to in (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT))
+ allowed_types = (onnx.TensorProto.BOOL, onnx.TensorProto.FLOAT, onnx.TensorProto.DOUBLE)
+ assert to in allowed_types, (
+     f"Invalid 'to' data type: {onnx.TensorProto.DataType.Name(to)}. "
+     f"Only {', '.join(onnx.TensorProto.DataType.Name(t) for t in allowed_types)} "
+     "are allowed for casting."
+ )

# Will be used for traced values
return (data.astype(numpy.float64),)
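As a rough usage illustration of the stricter cast handling (a sketch only; the import path is assumed from the file location above, and the exact error text may differ):

import numpy
import onnx

# Assumed import path, derived from src/concrete/ml/onnx/ops_impl.py
from concrete.ml.onnx.ops_impl import numpy_cast

data = numpy.array([[0.2, 0.8]])

# BOOL, FLOAT and DOUBLE targets are accepted for traced values and come back as float64
(result,) = numpy_cast(data, to=onnx.TensorProto.DOUBLE)
assert result.dtype == numpy.float64

# Any other target, e.g. INT64, now trips the assertion with an explicit message
try:
    numpy_cast(data, to=onnx.TensorProto.INT64)
except AssertionError as error:
    print(error)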
22 changes: 13 additions & 9 deletions src/concrete/ml/quantization/quantized_module.py
@@ -96,21 +96,21 @@ def __init__(
quant_layers_dict: Optional[Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]] = None,
onnx_model: Optional[onnx.ModelProto] = None,
):
- # Set base attributes for API consistency. This could be avoided if an abstract base class
- # is created for both Concrete ML models and QuantizedModule
- # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899

all_or_none_params = [
ordered_module_input_names,
ordered_module_output_names,
quant_layers_dict,
]
- assert_true(
+ if not (
all(v is None or v == {} for v in all_or_none_params)
- or not any(v is None or v == {} for v in all_or_none_params),
- "All of ordered_module_input_names, ordered_module_output_names, "
- "and quant_layers_dict must be provided if any one of them is provided.",
- )
+ or not any(v is None or v == {} for v in all_or_none_params)
+ ):
+ raise ValueError(
+ "Please either set all three 'ordered_module_input_names', "
+ "'ordered_module_output_names' and 'quant_layers_dict' or none of them."
+ )

self.ordered_module_input_names = (
tuple(ordered_module_input_names) if ordered_module_input_names else ()
)
@@ -120,9 +120,13 @@ def __init__(
self.quant_layers_dict = (
copy.deepcopy(quant_layers_dict) if quant_layers_dict is not None else {}
)

+ # Set base attributes for API consistency. This could be avoided if an abstract base class
+ # is created for both Concrete ML models and QuantizedModule
+ # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2899
self.input_quantizers: List[UniformQuantizer] = []
self.output_quantizers: List[UniformQuantizer] = []
- self.fhe_circuit: Union[None, Circuit] = None
+ self.fhe_circuit: Optional[Circuit] = None
self._is_compiled = False
self._onnx_model = onnx_model
self._post_processing_params: Dict[str, Any] = {}
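To illustrate the new behaviour (a hedged sketch assuming the module path matches the file above and that all three wiring arguments default to None): partial wiring now fails fast with a ValueError, while omitting all three arguments remains valid.

# Assumed import path, derived from src/concrete/ml/quantization/quantized_module.py
from concrete.ml.quantization.quantized_module import QuantizedModule

# All three wiring arguments omitted: a bare module can still be built
quantized_module = QuantizedModule()

# Providing only one of the three arguments is rejected explicitly
try:
    QuantizedModule(ordered_module_input_names=["input_0"])
except ValueError as error:
    print(error)  # asks for all three arguments or none of them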
15 changes: 15 additions & 0 deletions tests/quantization/test_quantized_module.py
@@ -496,3 +496,18 @@ def test_serialization(model_class, input_shape):
QuantizedModule,
equal_method=partial(quantized_module_predictions_are_equal, x=numpy_input),
)


def test_quantized_module_initialization_error():
"""Test initialization fails with mismatched parameters."""
# Initialize with invalid parameters
with pytest.raises(
ValueError,
match=r"Please either set all three 'ordered_module_input_names', "
r"'ordered_module_output_names' and 'quant_layers_dict' or none of them.",
):
QuantizedModule(
ordered_module_input_names=["input1", "input2"],
ordered_module_output_names=None, # This makes the combination invalid
quant_layers_dict={"layer1": (["input1"], "QuantizedOp")},
)
49 changes: 42 additions & 7 deletions tests/torch/test_compile_torch.py
@@ -1375,30 +1375,65 @@ def test_mono_parameter_rounding_warning(
)


@pytest.mark.parametrize(
"cast_type, should_fail, error_message",
[
(torch.bool, False, None),
(torch.float32, False, None),
(torch.float64, False, None),
(torch.int64, True, r"Invalid 'to' data type: INT64"),
],
)
def test_compile_torch_model_with_cast(cast_type, should_fail, error_message):
"""Test compiling a Torch model with various casts, expecting failure for invalid types."""
torch_input = torch.randn(100, 28)

class CastNet(nn.Module):
"""Network with cast."""

def __init__(self, cast_to):
super().__init__()
self.threshold = torch.tensor(0.5, dtype=torch.float32)
self.cast_to = cast_to

def forward(self, x):
"""Forward pass with dynamic cast."""
zeros = torch.zeros_like(x)
x = x + zeros
x = (x > self.threshold).to(self.cast_to)
return x

model = CastNet(cast_type)

if should_fail:
with pytest.raises(AssertionError, match=error_message):
compile_torch_model(model, torch_input, rounding_threshold_bits=3)
else:
compile_torch_model(model, torch_input, rounding_threshold_bits=3)


def test_onnx_no_input():
"""Test a torch model that has no input when converted to onnx."""

torch_input = torch.randn(100, 28)

- class SimplifiedNet(nn.Module):
+ class NoInputNet(nn.Module):
"""Network with no input in the onnx graph."""

def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28, 10)
self.threshold = torch.tensor(0.5, dtype=torch.float32)

def forward(self, x):
"""Forward pass."""
zeros = numpy.zeros_like(x)
x = x + zeros
x = (x > self.threshold).to(torch.float32)
x = self.fc1(x)
return x

- model = SimplifiedNet()
+ model = NoInputNet()

- with pytest.raises(AssertionError) as excinfo:
+ with pytest.raises(
+     AssertionError, match="Input 'x' is missing in the ONNX graph after export."
+ ):
compile_torch_model(model, torch_input, rounding_threshold_bits=3)

assert "Input 'x' is not present in the ONNX model" in str(excinfo.value)
