chore: update

kcelia committed Jan 30, 2024
1 parent 92e3fb9 commit 4e2fcda
Showing 7 changed files with 55 additions and 89 deletions.
6 changes: 3 additions & 3 deletions src/concrete/ml/onnx/convert.py
@@ -4,7 +4,7 @@
import tempfile
import warnings
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Tuple, Union

import numpy
import onnx
@@ -255,7 +255,7 @@ def get_equivalent_numpy_forward_from_onnx(
def get_equivalent_numpy_forward_from_onnx_tree(
onnx_model: onnx.ModelProto,
check_model: bool = True,
auto_truncate = None,
auto_truncate=None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
"""Get the numpy equivalent forward of the provided ONNX model for tree-based models only.
@@ -264,7 +264,7 @@ def get_equivalent_numpy_forward_from_onnx_tree(
forward.
check_model (bool): set to True to run the onnx checker on the model.
Defaults to True.
auto_truncate: This parameter is exclusively used for
auto_truncate (TODO): This parameter is exclusively used for
optimizing tree-based models. It contains the values of the least significant bits to
remove during the tree traversal, where the first value refers to the first comparison
(either "less" or "less_or_equal"), while the second value refers to the "Equal"
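For context, here is a minimal usage sketch of the tree-specific conversion entry point shown above. It is illustrative only: `onnx_model` is assumed to be the ONNX graph of an already fitted tree model, and the availability of `AutoTruncator` under `concrete.fhe` is inferred from the `cp.AutoTruncator` call later in this diff.

from concrete import fhe
from concrete.ml.onnx.convert import get_equivalent_numpy_forward_from_onnx_tree

# Truncator shared by the comparison nodes of the tree graph.
auto_truncate = fhe.AutoTruncator(target_msbs=1)

numpy_forward, checked_model = get_equivalent_numpy_forward_from_onnx_tree(
    onnx_model,  # assumed: ONNX graph of a fitted tree-based model
    auto_truncate=auto_truncate,
)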
27 changes: 6 additions & 21 deletions src/concrete/ml/onnx/onnx_impl_utils.py
@@ -3,10 +3,9 @@
from typing import Callable, Tuple, Union

import numpy

from concrete.fhe import truncate_bit_pattern, round_bit_pattern
from concrete.fhe import conv as fhe_conv
from concrete.fhe import ones as fhe_ones
from concrete.fhe import truncate_bit_pattern
from concrete.fhe.tracing import Tracer

from ..common.debugging import assert_true
@@ -241,21 +240,21 @@ def onnx_avgpool_compute_norm_const(
def rounded_comparison(
x: numpy.ndarray, y: numpy.ndarray, auto_truncate, operation: ComparisonOperationType
) -> Tuple[bool]:
"""Comparison operation using `round_bit_pattern` function.
"""Comparison operation using `truncate_bit_pattern` function.
`round_bit_pattern` rounds the bit pattern of an integer to the closer
`truncate_bit_pattern` truncates the bit pattern of an integer by removing its least significant bits.
It also checks for any potential overflow. If so, it readjusts the LSBs accordingly.
The parameter `lsbs_to_remove` in `round_bit_pattern` can either be an integer specifying the
The parameter `lsbs_to_remove` in `truncate_bit_pattern` can either be an integer specifying the
number of LSBs to remove, or an `AutoTruncator` object that determines the required number of LSBs
based on the specified number of MSBs to retain. But in our case, we choose to compute the LSBs
manually.
Args:
x (numpy.ndarray): Input tensor
y (numpy.ndarray): Input tensor
lsbs_to_remove (int): Number of the least significant bits to remove
operation (ComparisonOperationType): Comparison operation, which can `<`, `<=` and `==`
auto_truncate: Truncator object used to determine the number of least significant bits to remove
operation (ComparisonOperationType): Comparison operation, which can be `<`, `<=` or `==`
Returns:
Tuple[bool]: If x and y satisfy the comparison operator.
@@ -265,17 +264,3 @@
rounded_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=auto_truncate)

return (operation(rounded_subtraction),)

# # To determine if 'x' 'operation' 'y' (operation being <, >, >=, <=), we evaluate 'x - y'
# rounded_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=auto_truncate)
# assert isinstance(lsbs_to_remove, int)

# # Workaround: in this context, `round_bit_pattern` is used as a truncate operation.
# # Consequently, we subtract a term, called `half` that will subsequently be re-added during the
# # `round_bit_pattern` process.
# half = 1 << (lsbs_to_remove - 1)

# # To determine if 'x' 'operation' 'y' (operation being <, >, >=, <=), we evaluate 'x - y'
# rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=lsbs_to_remove)

# return (operation(rounded_subtraction),)
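To illustrate what `rounded_comparison` computes, here is a plain-numpy sketch that emulates the effect of `truncate_bit_pattern` by clearing the least significant bits of the difference. The helper names and sample values are illustrative, not library code.

import numpy

def emulate_truncate(v, lsbs_to_remove):
    # Clear the least significant bits; the arithmetic shift preserves the sign.
    return (v >> lsbs_to_remove) << lsbs_to_remove

def emulate_less(x, y, lsbs_to_remove):
    # x < y  <=>  x - y < 0, and truncating the difference keeps its sign,
    # so the comparison stays correct up to the removed precision.
    return emulate_truncate(x - y, lsbs_to_remove) < 0

x = numpy.array([3, 10, 250])
y = numpy.array([7, 10, 12])
print(emulate_less(x, y, lsbs_to_remove=2))  # [ True False False]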
4 changes: 2 additions & 2 deletions src/concrete/ml/onnx/onnx_utils.py
@@ -213,7 +213,7 @@

# Original file:
# https://github.com/google/jax/blob/f6d329b2d9b5f83c6a59e5739aa1ca8d4d1ffa1c/examples/onnx2xla.py
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Tuple

import numpy
import onnx
@@ -413,7 +413,7 @@
}
# All numpy operators used for tree-based models that support auto rounding
ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL = {
"Less": rounded_numpy_less_for_trees,
"Less": rounded_numpy_less_for_trees, # type: ignore[dict-item]
"Equal": rounded_numpy_equal_for_trees,
"LessOrEqual": rounded_numpy_less_or_equal_for_trees,
}
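As a quick illustration, the dispatch table above can be exercised outside FHE tracing: with `auto_truncate=None` the registered implementations fall back to the plain numpy comparisons. The sample arrays are illustrative.

import numpy

from concrete.ml.onnx.onnx_utils import ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL

x = numpy.array([1, 5, 9])
y = numpy.array([4, 5, 2])

# Look up the numpy implementation registered for the ONNX "Less" operator.
less_impl = ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL["Less"]
print(less_impl(x, y, auto_truncate=None))  # (array([ True, False, False]),)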
24 changes: 8 additions & 16 deletions src/concrete/ml/onnx/ops_impl.py
@@ -892,7 +892,7 @@ def rounded_numpy_equal_for_trees(
x: numpy.ndarray,
y: numpy.ndarray,
*,
auto_truncate = None,
auto_truncate=None,
) -> Tuple[numpy.ndarray]:
"""Compute rounded equal in numpy according to ONNX spec for tree-based models only.
@@ -925,7 +925,6 @@ def rounded_numpy_equal_for_trees(
return (numpy.equal(x, y),)



def numpy_equal_float(
x: numpy.ndarray,
y: numpy.ndarray,
@@ -1096,7 +1095,6 @@ def rounded_numpy_less_for_trees(
# numpy.less(x, y) is equivalent to :
# x - y < 0 => truncate_bit_pattern(x - y) < 0
if auto_truncate is not None:
#print("Use truncate for <")
return rounded_comparison(x, y, auto_truncate, operation=lambda x: x < 0)

# Else, default numpy_less operator
@@ -1145,7 +1143,7 @@ def rounded_numpy_less_or_equal_for_trees(
x: numpy.ndarray,
y: numpy.ndarray,
*,
auto_truncate = None,
auto_truncate=None,
) -> Tuple[numpy.ndarray]:
"""Compute rounded less or equal in numpy according to ONNX spec for tree-based models only.
@@ -1161,21 +1159,15 @@
Tuple[numpy.ndarray]: Output tensor
"""

# numpy.less_equal(x, y) <= 0 is equivalent to :
# np.less_equal(x, y), truncate_bit_pattern((y - x), lsbs_to_remove=r) >= 0
# option 1: x - y <= 0 => round_bit_pattern(x - y) <= 0
# gives bad results for : 0 < x - y <= 2**lsbs_to_remove because truncate_bit_pattern(x - y, lsb) = 0
# option 2: y - x >= 0 => round_bit_pattern(y - x) >= 0

if auto_truncate is not None:
#print("Use truncate for <=")
return rounded_comparison(y, x, auto_truncate, operation=lambda x: x >= 0)
# numpy.less_equal(x, y) is equivalent to :
# option 1: x - y <= 0 => round_bit_pattern(x - y + half) <= 0 or
# option 2: y - x >= 0 => round_bit_pattern(y - x - half) >= 0
# option 1: x - y <= 0 => truncate_bit_pattern(x - y) <= 0 or
# option 2: y - x >= 0 => truncate_bit_pattern(y - x) >= 0

# Option 2 is selected because it adheres to the established pattern in `rounded_comparison`
# which does: (a - b) - half.
# which does: (a - b).

if auto_truncate is not None:
return rounded_comparison(y, x, auto_truncate, operation=lambda x: x >= 0)

# Else, default numpy_less_or_equal operator
return numpy_less_or_equal(x, y)
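The removed comment above recorded the concrete reason for preferring option 2: with option 1, a difference strictly between 0 and 2**lsbs_to_remove truncates to 0, so the `<= 0` test wrongly succeeds. A small sketch of that failure mode, again emulating truncation with shifts on illustrative values:

import numpy

def emulate_truncate(v, lsbs_to_remove):
    # Clear the least significant bits (arithmetic shift keeps the sign).
    return (v >> lsbs_to_remove) << lsbs_to_remove

x, y, lsbs = numpy.array([9]), numpy.array([8]), 2  # 9 <= 8 is False

# Option 1: truncate(x - y) = truncate(1) = 0, and 0 <= 0 is True (wrong).
print(emulate_truncate(x - y, lsbs) <= 0)  # [ True]

# Option 2: truncate(y - x) = truncate(-1) = -4, and -4 >= 0 is False (correct).
print(emulate_truncate(y - x, lsbs) >= 0)  # [False]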
22 changes: 10 additions & 12 deletions src/concrete/ml/sklearn/base.py
@@ -102,10 +102,10 @@
# Most significant bits to retain when applying rounding to the tree
MSB_TO_KEEP_FOR_TREES = 1

# Enable rounding feature for all tree-based models by default
# Enable truncate feature for all tree-based models by default
# Note: This setting is fixed and cannot be altered by users
# However, for internal testing purposes, we retain the capability to disable this feature
os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1")
os.environ["TREES_USE_TRUNCATE"] = os.environ.get("TREES_USE_TRUNCATE", "1")

# pylint: disable=too-many-public-methods

@@ -1326,8 +1326,8 @@ def fit(self, X: Data, y: Target, **fit_parameters):
self.output_quantizers = []

#: Determines the LSB to remove given a `target_msbs`
self.auto_truncate = cp.AutoTruncator(target_msbs=MSB_TO_KEEP_FOR_TREES)
self._auto_truncate = cp.AutoTruncator(target_msbs=MSB_TO_KEEP_FOR_TREES)

X, y = check_X_y_and_assert_multi_output(X, y)

q_X = numpy.zeros_like(X)
@@ -1353,25 +1353,23 @@
assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message()

# Enable optimized computation
enable_truncate = os.environ.get("TREES_USE_ROUNDING", "1") == "1"
enable_truncate = os.environ.get("TREES_USE_TRUNCATE", "1") == "1"

if not enable_truncate:
warnings.simplefilter("always")
warnings.warn(
"Using Concrete tree-based models without the `rounding feature` is deprecated. "
"Consider setting 'use_rounding' to `True` for making the FHE inference faster "
"Using Concrete tree-based models without the `truncate feature` is deprecated. "
"Consider setting 'use_truncate' to `True` for making the FHE inference faster "
"and key generation.",
category=DeprecationWarning,
stacklevel=2,
)
self.auto_truncate = None
self._auto_truncate = None

print(f"{self.auto_truncate=}")

self._tree_inference, self.output_quantizers, self.onnx_model_ = tree_to_numpy(
self.sklearn_model,
q_X,
auto_truncate=self.auto_truncate,
auto_truncate=self._auto_truncate,
fhe_ensembling=self._fhe_ensembling,
framework=self.framework,
output_n_bits=self.n_bits["op_leaves"],
@@ -1380,7 +1378,7 @@
# Adjust the truncate
if enable_truncate:
inputset = numpy.array(list(_get_inputset_generator(q_X))).astype(int)
self.auto_truncate.adjust(self._tree_inference, inputset)
self._auto_truncate.adjust(self._tree_inference, inputset)
self._tree_inference(q_X.astype("int"))

self._is_fitted = True
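The fit-time flow above reduces to: build an `AutoTruncator`, trace the tree inference with it, then adjust it on a representative integer input set before compilation. A standalone sketch of that adjustment step, with a toy function `f` and a random input set standing in for the generated tree inference and the quantized training data (and assuming `AutoTruncator` is exposed by `concrete.fhe`, as the `cp.AutoTruncator` call suggests):

import numpy

from concrete import fhe

auto_truncate = fhe.AutoTruncator(target_msbs=1)  # mirrors MSB_TO_KEEP_FOR_TREES

def f(x):
    # Stand-in for the tree inference: truncate before a comparison.
    return fhe.truncate_bit_pattern(x, lsbs_to_remove=auto_truncate) < 4

inputset = [numpy.random.randint(0, 16, size=(4,)) for _ in range(10)]

# Determines how many LSBs to drop so that `target_msbs` bits are kept.
auto_truncate.adjust(f, inputset)
print(f(numpy.array([1, 5, 9, 15])))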
15 changes: 3 additions & 12 deletions src/concrete/ml/sklearn/tree_to_numpy.py
@@ -1,7 +1,7 @@
"""Implements the conversion of a tree model to a numpy function."""
import math
import warnings
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Tuple

import numpy
import onnx
@@ -330,7 +330,7 @@ def tree_to_numpy(
model: Callable,
x: numpy.ndarray,
framework: str,
auto_truncate: bool = None,
auto_truncate=None,
fhe_ensembling: bool = False,
output_n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
) -> Tuple[Callable, List[UniformQuantizer], onnx.ModelProto]:
@@ -339,7 +339,7 @@
Args:
model (Callable): The tree model to convert.
x (numpy.ndarray): The input data.
auto_truncate (bool): Determines whether the rounding feature is enabled or disabled.
auto_truncate (TODO): Truncator that determines the least significant bits to remove
during the tree traversal, or None to disable the truncate feature. Defaults to None.
fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE.
Default to False.
@@ -365,15 +365,6 @@
# Execute with 1 example for efficiency in large data scenarios to prevent slowdown
onnx_model = get_onnx_model(model, x[:1], framework)

# # Compute for tree-based models the LSB to remove in stage 1 and stage 2
# if use_rounding:
# # First LSB refers to Less or LessOrEqual comparisons
# # Second LSB refers to Equal comparison
# lsbs_to_remove_for_trees = _compute_lsb_to_remove_for_trees(onnx_model, x)

# # mypy
# assert len(lsbs_to_remove_for_trees) == 2

# Get the expected number of ONNX outputs in the sklearn model.
expected_number_of_outputs = 1 if is_regressor_or_partial_regressor(model) else 2

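Putting the pieces together, the call site in `base.py` earlier in this diff uses the new `tree_to_numpy` signature roughly as sketched below. This is a condensed, illustrative view: `sklearn_model` and `q_X` are assumed to be a fitted tree-based model and its quantized training inputs, and the framework and bit-width values are placeholders.

from concrete import fhe
from concrete.ml.sklearn.tree_to_numpy import tree_to_numpy

auto_truncate = fhe.AutoTruncator(target_msbs=1)

tree_inference, output_quantizers, onnx_model = tree_to_numpy(
    sklearn_model,              # assumed: fitted tree-based model
    q_X,                        # assumed: quantized training inputs
    framework="sklearn",        # placeholder framework name
    auto_truncate=auto_truncate,
    fhe_ensembling=False,
    output_n_bits=8,            # placeholder bit-width
)

# The truncator is then adjusted on a representative input set before compilation:
# auto_truncate.adjust(tree_inference, inputset)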
46 changes: 23 additions & 23 deletions tests/sklearn/test_sklearn_models.py
@@ -1138,61 +1138,61 @@ def check_load_fitted_sklearn_linear_models(model_class, n_bits, x, y, check_flo
)


def check_rounding_consistency(
def check_truncate_consistency(
model,
x,
y,
predict_method,
metric,
is_weekly_option,
):
"""Test that Concrete ML without and with rounding are 'equivalent'."""
"""Test that Concrete ML without and with truncate are 'equivalent'."""

# Run the test with more samples during weekly CIs
if is_weekly_option:
fhe_test = get_random_samples(x, n_sample=5)

# Check that rounding is enabled
assert os.environ.get("TREES_USE_ROUNDING") == "1", "'TREES_USE_ROUNDING' is not enabled"
# Check that truncate is enabled
assert os.environ.get("TREES_USE_TRUNCATE") == "1", "'TREES_USE_TRUNCATE' is not enabled"

# Fit and compile with rounding enabled
# Fit and compile with truncate enabled
fit_and_compile(model, x, y)

rounded_predict_quantized = predict_method(x, fhe="disable")
rounded_predict_simulate = predict_method(x, fhe="simulate")
truncate_predict_quantized = predict_method(x, fhe="disable")
truncate_predict_simulate = predict_method(x, fhe="simulate")

# Compute the FHE predictions only during weekly CIs
if is_weekly_option:
rounded_predict_fhe = predict_method(fhe_test, fhe="execute")
truncate_predict_fhe = predict_method(fhe_test, fhe="execute")

with pytest.MonkeyPatch.context() as mp_context:

# Disable rounding
mp_context.setenv("TREES_USE_ROUNDING", "0")
# Disable truncate
mp_context.setenv("TREES_USE_TRUNCATE", "0")

# Check that rounding is disabled
assert os.environ.get("TREES_USE_ROUNDING") == "0", "'TREES_USE_ROUNDING' is not disabled"
# Check that truncate is disabled
assert os.environ.get("TREES_USE_TRUNCATE") == "0", "'TREES_USE_TRUNCATE' is not disabled"

with pytest.warns(
DeprecationWarning,
match=(
"Using Concrete tree-based models without the `rounding feature` is " "deprecated.*"
"Using Concrete tree-based models without the `truncate feature` is " "deprecated.*"
),
):

# Fit and compile without rounding
# Fit and compile without truncate
fit_and_compile(model, x, y)

not_rounded_predict_quantized = predict_method(x, fhe="disable")
not_rounded_predict_simulate = predict_method(x, fhe="simulate")
not_truncate_predict_quantized = predict_method(x, fhe="disable")
not_truncate_predict_simulate = predict_method(x, fhe="simulate")

metric(rounded_predict_quantized, not_rounded_predict_quantized)
metric(rounded_predict_simulate, not_rounded_predict_simulate)
metric(truncate_predict_quantized, not_truncate_predict_quantized)
metric(truncate_predict_simulate, not_truncate_predict_simulate)

# Compute the FHE predictions only during weekly CIs
if is_weekly_option:
not_rounded_predict_fhe = predict_method(fhe_test, fhe="execute")
metric(rounded_predict_fhe, not_rounded_predict_fhe)
not_truncate_predict_fhe = predict_method(fhe_test, fhe="execute")
metric(truncate_predict_fhe, not_truncate_predict_fhe)

# Check that the maximum bit-width of the circuit with rounding is at most:
# maximum bit-width (of the circuit without rounding) + 2
@@ -1881,7 +1881,7 @@ def test_linear_models_have_no_tlu(
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4179
@pytest.mark.parametrize("model_class, parameters", get_sklearn_tree_models_and_datasets())
@pytest.mark.parametrize("n_bits", [2, 5, 10])
def test_rounding_consistency_for_regular_models(
def test_truncate_consistency_for_regular_models(
model_class,
parameters,
n_bits,
@@ -1894,7 +1894,7 @@ def test_rounding_consistency_for_regular_models(
"""Test that Concrete ML without and with rounding are 'equivalent'."""

if verbose:
print("Run check_rounding_consistency")
print("Run check_truncate_consistency")

model = instantiate_model_generic(model_class, n_bits=n_bits)

@@ -1909,7 +1909,7 @@
predict_method = model.predict
metric = check_accuracy

check_rounding_consistency(
check_truncate_consistency(
model,
x,
y,