Skip to content

Commit

Permalink
chore: update
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia committed Jan 15, 2024
1 parent f8dccfa commit 99e34d4
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 57 deletions.
17 changes: 10 additions & 7 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,16 +1430,19 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.
y_pred = self.post_processing(y_pred)
return y_pred

# def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:

# Sum all tree outputs
# Remove the sum once we handle multi-precision circuits
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451
if os.getenv("TREES_USE_FHE_SUM") == "0":
y_preds = numpy.sum(y_preds, axis=-1)

assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
return y_preds

# # Sum all tree outputs
# # Remove the sum once we handle multi-precision circuits
# # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451
# y_preds = numpy.sum(y_preds, axis=-1)
return super().post_processing(y_preds)

# assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
# return y_preds

class BaseTreeRegressorMixin(BaseTreeEstimatorMixin, sklearn.base.RegressorMixin, ABC):
"""Mixin class for tree-based regressors.
Expand Down
18 changes: 13 additions & 5 deletions src/concrete/ml/sklearn/tree_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
OPSET_VERSION_FOR_ONNX_EXPORT,
get_equivalent_numpy_forward_from_onnx_tree,
)
from ..onnx.onnx_model_manipulations import clean_graph_after_node_op_type, remove_node_types
from ..onnx.onnx_model_manipulations import (
clean_graph_after_node_op_type,
clean_graph_at_node_op_type,
remove_node_types,
)
from ..onnx.onnx_utils import get_op_type
from ..quantization import QuantizedArray
from ..quantization.quantizers import UniformQuantizer
Expand Down Expand Up @@ -142,13 +146,14 @@ def add_transpose_after_last_node(onnx_model: onnx.ModelProto):
# Get the output node
output_node = onnx_model.graph.output[0]

# When using FHE sum for tree ensembles, create the node with perm attribute equal to (1, 0)
if os.getenv("TREES_USE_FHE_SUM") == "1":
# Create the node with perm attribute equal to (1, 0)
perm = [1, 0]

# Otherwise, create the node with perm attribute equal to (2, 1, 0)
else:
# Create the node with perm attribute equal to (2, 1, 0)
perm = [2, 1, 0]

transpose_node = onnx.helper.make_node(
"Transpose",
inputs=[output_node.name],
Expand Down Expand Up @@ -246,7 +251,10 @@ def tree_onnx_graph_preprocessing(

# Cut the graph after the ReduceSum node to remove
# argmax, sigmoid, softmax from the graph.
clean_graph_after_node_op_type(onnx_model, "ReduceSum")
if os.getenv("TREES_USE_FHE_SUM") == "1":
clean_graph_after_node_op_type(onnx_model, "ReduceSum")
else:
clean_graph_at_node_op_type(onnx_model, "ReduceSum")

if framework == "xgboost":
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2778
Expand Down
15 changes: 6 additions & 9 deletions tests/sklearn/test_dump_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau
str_model = onnx.helper.printable_graph(onnx_model.graph)
print(f"{model_name}:")
print(str_model)

# Test equality when it does not depend on seeds
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3266
if not is_model_class_in_a_list(model_class, _get_sklearn_tree_models(select="RandomForest")):
Expand Down Expand Up @@ -228,7 +227,7 @@ def test_dump(
%transposed_output = Transpose[perm = [1, 0]](%/_operators.0/ReduceSum_output_0)
"""
if os.getenv("TREES_USE_FHE_SUM") == "1"
else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)"
else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)\n "
)
+ """return %transposed_output
}""",
Expand Down Expand Up @@ -307,7 +306,7 @@ def test_dump(
%transposed_output = Transpose[perm = [1, 0]](%/_operators.0/ReduceSum_output_0)
"""
if os.getenv("TREES_USE_FHE_SUM") == "1"
else ""
else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)\n "
)
+ """return %transposed_output
}""",
Expand Down Expand Up @@ -359,8 +358,7 @@ def test_dump(
return %/_operators.0/ReduceSum_output_0
}"""
if os.getenv("TREES_USE_FHE_SUM") == "1"
else """return %/_operators.0/Reshape_4_output_0
}"""
else "return %/_operators.0/Reshape_4_output_0\n}"
),
"RandomForestRegressor": """graph torch_jit (
%input_0[DOUBLE, symx10]
Expand Down Expand Up @@ -401,9 +399,8 @@ def test_dump(
%/_operators.0/Constant_1_output_0[INT64, 2]
%/_operators.0/Constant_2_output_0[INT64, 3]
%/_operators.0/Constant_3_output_0[INT64, 3]
%/_operators.0/Constant_4_output_0[INT64, 3]
"""
+ ("%onnx::ReduceSum_27[INT64, 1]" if os.getenv("TREES_USE_FHE_SUM") == "1" else "")
%/_operators.0/Constant_4_output_0[INT64, 3]"""
+ ("\n %onnx::ReduceSum_27[INT64, 1]" if os.getenv("TREES_USE_FHE_SUM") == "1" else "")
+ """
) {
%/_operators.0/Gemm_output_0 = Gemm[alpha = 1, beta = 0, transB = 1](%_operators.0.weight_1, %input_0)
Expand All @@ -424,7 +421,7 @@ def test_dump(
return %/_operators.0/ReduceSum_output_0
}"""
if os.getenv("TREES_USE_FHE_SUM") == "1"
else "return %/_operators.0/Reshape_4_output_0"
else """return %/_operators.0/Reshape_4_output_0\n}"""
),
"LinearRegression": """graph torch_jit (
%input_0[DOUBLE, symx10]
Expand Down
85 changes: 49 additions & 36 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from concrete.ml.common.serialization.loaders import load, loads
from concrete.ml.common.utils import (
USE_OLD_VL,
array_allclose_and_same_shape,
get_model_class,
get_model_name,
is_classifier_or_partial_classifier,
Expand Down Expand Up @@ -1206,47 +1207,73 @@ def check_rounding_consistency(


def check_fhe_sum_consistency(
model_class,
x,
predict_method,
metric,
y,
n_bits,
is_weekly_option,
):
"""Test that Concrete ML without and with rounding are 'equivalent'."""
"""Test that Concrete ML without and with FHE sum are 'equivalent'."""

# Run the test with more samples during weekly CIs
if is_weekly_option:
fhe_test = get_random_samples(x, n_sample=5)

# By default, FHE_SUM is disabled
fhe_sum_disabled = os.getenv("TREES_USE_FHE_SUM") == "1"
# By default, the summation of tree ensemble outputs is done in clear
fhe_sum_disabled = os.getenv("TREES_USE_FHE_SUM") == "0"
assert fhe_sum_disabled

model_ref = instantiate_model_generic(model_class, n_bits=n_bits)
fit_and_compile(model_ref, x, y)

# Check `predict_proba` for classifiers and `predict` for regressors
predict_method = (
model_ref.predict_proba
if is_classifier_or_partial_classifier(model_class)
else model_ref.predict
)

non_fhe_sum_predict_quantized = predict_method(x, fhe="disable")
non_fhe_sum_predict_simulate = predict_method(x, fhe="simulate")

# Sanity check
array_allclose_and_same_shape(non_fhe_sum_predict_quantized, non_fhe_sum_predict_simulate)

# Compute the FHE predictions only during weekly CIs
if is_weekly_option:
rounded_predict_fhe = predict_method(fhe_test, fhe="execute")
non_fhe_sum_predict_fhe = predict_method(fhe_test, fhe="execute")

with pytest.MonkeyPatch.context() as mp_context:

# Enable FHE sum
mp_context.setenv("TREES_USE_FHE_SUM", "0")
# Enable the FHE summation of tree ensemble outputs
mp_context.setenv("TREES_USE_FHE_SUM", "1")

# Check that rounding is disabled
fhe_sum_enbled = os.environ.get("TREES_USE_FHE_SUM") == "0"
assert fhe_sum_enbled
# Check that the summation of tree ensemble outputs is enabled
fhe_sum_enabled = os.environ.get("TREES_USE_FHE_SUM") == "1"
assert fhe_sum_enabled

model = model_class(**model_ref.get_params())
fit_and_compile(model, x, y)

# Check `predict_proba` for classifiers and `predict` for regressors
predict_method = (
model.predict_proba
if is_classifier_or_partial_classifier(model_class)
else model.predict
)

fhe_sum_predict_quantized = predict_method(x, fhe="disable")
fhe_sum_predict_simulate = predict_method(x, fhe="simulate")

metric(non_fhe_sum_predict_quantized, fhe_sum_predict_quantized)
metric(non_fhe_sum_predict_simulate, fhe_sum_predict_simulate)
# Sanity check
array_allclose_and_same_shape(fhe_sum_predict_quantized, fhe_sum_predict_simulate)

# Compute the FHE predictions only during weekly CIs
if is_weekly_option:
not_rounded_predict_fhe = predict_method(fhe_test, fhe="execute")
metric(rounded_predict_fhe, not_rounded_predict_fhe)
# Check that we have the exact same predictions
array_allclose_and_same_shape(fhe_sum_predict_quantized, non_fhe_sum_predict_quantized)
array_allclose_and_same_shape(fhe_sum_predict_simulate, non_fhe_sum_predict_simulate)
if is_weekly_option:
fhe_sum_predict_fhe = predict_method(fhe_test, fhe="execute")
array_allclose_and_same_shape(fhe_sum_predict_fhe, non_fhe_sum_predict_fhe)


# Neural network models are skipped for this test
Expand Down Expand Up @@ -1937,34 +1964,20 @@ def test_fhe_sum_for_tree_based_models(
parameters,
n_bits,
load_data,
check_r2_score,
check_accuracy,
is_weekly_option,
default_configuration,
verbose=True,
):
"""Test that Concrete ML without and with rounding are 'equivalent'."""

if verbose:
print("Run check_rounding_consistency")
print("Run check_fhe_sum_consistency")

model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option)

# Compile the model to make sure we consider all possible attributes during the serialization
model.compile(x, default_configuration)

# Check `predict_proba` for classifiers
if is_classifier_or_partial_classifier(model):
predict_method = model.predict_proba
metric = check_r2_score
else:
# Check `predict` for regressors
predict_method = model.predict
metric = check_accuracy
x, y = get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option)

check_fhe_sum_consistency(
model_class,
x,
predict_method,
metric,
y,
n_bits,
is_weekly_option,
)

0 comments on commit 99e34d4

Please sign in to comment.