Skip to content

Commit

Permalink
fix: working tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Jul 16, 2024
1 parent be84cb3 commit 844df5d
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 6 deletions.
13 changes: 12 additions & 1 deletion src/concrete/ml/onnx/ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# pylint: disable=ungrouped-imports
from concrete.ml.common import utils
from concrete.ml.common.debugging import assert_true
from concrete.ml.common.debugging import assert_false, assert_true
from concrete.ml.onnx.onnx_impl_utils import (
compute_onnx_pool_padding,
numpy_onnx_pad,
Expand Down Expand Up @@ -1808,6 +1808,17 @@ def numpy_brevitas_quant(
assert_true(signed in (1, 0), "Signed flag in Brevitas quantizer must be 0/1")
assert_true(narrow in (1, 0), "Narrow range flag in Brevitas quantizer must be 0/1")

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4544
# Remove this workaround when brevitas export is fixed
if signed == 0 and narrow == 1:
signed = 1
narrow = 0

assert_false(
signed == 0 and narrow == 1,
"Can not use narrow range for non-signed Brevitas quantizers",
)

# Compute the re-scaled values
y = x / scale
y = y + zero_point
Expand Down
13 changes: 12 additions & 1 deletion src/concrete/ml/quantization/quantized_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from concrete.fhe import tag, univariate, zeros
from typing_extensions import SupportsIndex

from ..common.debugging import assert_true
from ..common.debugging import assert_false, assert_true
from ..onnx.onnx_impl_utils import (
compute_onnx_pool_padding,
numpy_onnx_pad,
Expand Down Expand Up @@ -1974,6 +1974,17 @@ def check_float(v, err_msg):
self.is_signed = bool(attrs["signed"])
self.is_narrow = bool(attrs["narrow"])

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4544
# Remove this workaround when brevitas export is fixed
if self.is_signed is False and self.is_narrow is True:
self.is_signed = True
self.is_narrow = False

assert_false(
not self.is_signed and self.is_narrow,
"Can not use narrow range for non-signed Brevitas quantizers",
)

# To ensure de-quantization produces floats, the following parameters must be float.
# This should be default export setting in Brevitas
check_float(
Expand Down
3 changes: 3 additions & 0 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,6 +1374,9 @@ def from_sklearn_model(

# Get the onnx model, all operations needed to load it properly will be done on it.
n_features = model.n_features_in_
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4545
# Execute with 2 examples for efficiency in large data scenarios to prevent slowdown
# but also to work around the HB export issue.
dummy_input = numpy.zeros((2, n_features))
framework = "xgboost" if isinstance(sklearn_model, XGBModel) else "sklearn"
onnx_model = get_onnx_model(
Expand Down
4 changes: 3 additions & 1 deletion src/concrete/ml/sklearn/tree_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,9 @@ def tree_to_numpy(
f"framework={framework} is not supported. It must be either 'xgboost' or 'sklearn'",
)

# Execute with 1 example for efficiency in large data scenarios to prevent slowdown
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4545
# Execute with 2 examples for efficiency in large data scenarios to prevent slowdown
# but also to work around the HB export issue
onnx_model = get_onnx_model(model, x[:2] if x.shape[0] > 1 else x, framework)

# Compute for tree-based models the LSB to remove in stage 1 and stage 2
Expand Down
8 changes: 8 additions & 0 deletions src/concrete/ml/sklearn/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def load_dict(cls, metadata: Dict):
obj.onnx_model_ = metadata["onnx_model_"]
obj.output_quantizers = metadata["output_quantizers"]
obj._fhe_ensembling = metadata["_fhe_ensembling"]

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4545
# Execute with 2 examples for efficiency in large data scenarios to prevent slowdown
# but also to work around the HB export issue.
obj._tree_inference = tree_to_numpy(
obj.sklearn_model,
numpy.tile(numpy.zeros((len(obj.input_quantizers),))[None, ...], [2, 1]),
Expand Down Expand Up @@ -492,6 +496,10 @@ def load_dict(cls, metadata: Dict):
obj.onnx_model_ = metadata["onnx_model_"]
obj.output_quantizers = metadata["output_quantizers"]
obj._fhe_ensembling = metadata["_fhe_ensembling"]

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4545
# Execute with 2 examples for efficiency in large data scenarios to prevent slowdown
# but also to work around the HB export issue.
obj._tree_inference = tree_to_numpy(
obj.sklearn_model,
numpy.tile(numpy.zeros((len(obj.input_quantizers),))[None, ...], [2, 1]),
Expand Down
9 changes: 6 additions & 3 deletions tests/quantization/test_quantized_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,9 +1658,12 @@ def create_layer(is_signed, narrow):
)

if not is_signed and narrow:
with pytest.raises(AssertionError, match=r"Can not use narrow range.*"):
quant = create_layer(1 if is_signed else 0, 1 if narrow else 0)
return
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4544
# Reinstate warning check when brevitas export is fixed
pytest.skip("Skipping checking of invalid brevitas quant setting (signed=0,narrow=1)")
# with pytest.raises(AssertionError, match=r"Can not use narrow range.*"):
# quant = create_layer(1 if is_signed else 0, 1 if narrow else 0)
# return

quant = create_layer(1 if is_signed else 0, 1 if narrow else 0)

Expand Down

0 comments on commit 844df5d

Please sign in to comment.