Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: remove manual set of multi and mono strategy #256

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest
import torch
from concrete.fhe import Graph as CPGraph
from concrete.fhe import ParameterSelectionStrategy
from concrete.fhe.compilation import Circuit, Configuration
from concrete.fhe.mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from sklearn.datasets import make_classification, make_regression
Expand Down Expand Up @@ -147,9 +146,6 @@ def pytest_sessionfinish(session: pytest.Session, exitstatus): # pylint: disabl
def default_configuration():
"""Return the default test compilation configuration."""

# Remove parameter_selection_strategy once it is set to multi-parameter in Concrete Python
# by default
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3860
# Parameter `enable_unsafe_features` and `use_insecure_key_cache` are needed in order to be
# able to cache generated keys through `insecure_key_cache_location`. As the name suggests,
# these parameters are unsafe and should only be used for debugging in development
Expand All @@ -158,7 +154,6 @@ def default_configuration():
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location="ConcreteNumpyKeyCache",
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)


Expand Down
39 changes: 2 additions & 37 deletions src/concrete/ml/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import enum
import string
import warnings
from functools import partial
from types import FunctionType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -587,43 +586,9 @@ def all_values_are_of_dtype(*values: Any, dtypes: Union[str, List[str]]) -> bool
return all(_is_of_dtype(value, supported_dtypes) for value in values)


def set_multi_parameter_in_configuration(configuration: Optional[Configuration], **kwargs):
"""Build a Configuration instance with multi-parameter strategy, unless one is already given.

If the given Configuration instance is not None and the parameter strategy is set to MONO, a
warning is raised in order to make sure the user did it on purpose.

Args:
configuration (Optional[Configuration]): The configuration to consider.
**kwargs: Additional parameters to use for instantiating a new Configuration instance, if
configuration is None.

Returns:
configuration (Configuration): A configuration with multi-parameter strategy.
"""
assert (
"parameter_selection_strategy" not in kwargs
), "Please do not provide a parameter_selection_strategy parameter as it will be set to MULTI."
if configuration is None:
configuration = Configuration(
parameter_selection_strategy=ParameterSelectionStrategy.MULTI, **kwargs
)

elif configuration.parameter_selection_strategy == ParameterSelectionStrategy.MONO:
warnings.warn(
"Setting the parameter_selection_strategy to mono-parameter is not recommended as it "
"can deteriorate performances. If you set it voluntarily, this message can be ignored. "
"Else, please set parameter_selection_strategy to ParameterSelectionStrategy.MULTI "
"when initializing the Configuration instance.",
stacklevel=2,
)

return configuration


# Remove this function once Concrete Python fixes the multi-parameter bug with fully-leveled
# Remove this function once Concrete Python fixes the multi-parameter bug with KNN
# circuits
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978
def force_mono_parameter_in_configuration(configuration: Optional[Configuration], **kwargs):
"""Force configuration to mono-parameter strategy.

Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
check_there_is_no_p_error_options_in_configuration,
generate_proxy_function,
manage_parameters_for_pbs_errors,
set_multi_parameter_in_configuration,
to_tuple,
)
from .base_quantized_op import ONNXOpInputOutputType, QuantizedOp
Expand Down Expand Up @@ -639,11 +638,6 @@ def compile(
# Find the right way to set parameters for compiler, depending on the way we want to default
p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

# Remove this function once the default strategy is set to multi-parameter in Concrete
# Python
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3860
configuration = set_multi_parameter_in_configuration(configuration)

# Jit compiler is now deprecated and will soon be removed, it is thus forced to False
# by default
self.fhe_circuit = compiler.compile(
Expand Down
42 changes: 0 additions & 42 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
force_mono_parameter_in_configuration,
generate_proxy_function,
manage_parameters_for_pbs_errors,
set_multi_parameter_in_configuration,
)
from ..onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT
from ..onnx.onnx_model_manipulations import clean_graph_after_node_op_type, remove_node_types
Expand Down Expand Up @@ -1363,27 +1362,6 @@ def _get_module_to_compile(self) -> Union[Compiler, QuantizedModule]:
return compiler

def compile(self, *args, **kwargs) -> Circuit:

# Factorize this in the base class once Concrete Python fixes the multi-parameter bug
# with fully-leveled circuits
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
# Remove this function once the default strategy is set to multi-parameter in Concrete
# Python
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3860
# If a configuration instance is given as a positional parameter, set the strategy to
# multi-parameter
if len(args) >= 2:
configuration = set_multi_parameter_in_configuration(args[1])
args_list = list(args)
args_list[1] = configuration
args = tuple(args_list)

# Else, retrieve the configuration in kwargs if it exists, or create a new one, and set the
# strategy to multi-parameter
else:
configuration = kwargs.get("configuration", None)
kwargs["configuration"] = set_multi_parameter_in_configuration(configuration)

BaseEstimator.compile(self, *args, **kwargs)

# Check that the graph only has a single output
Expand Down Expand Up @@ -1638,26 +1616,6 @@ def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray:
y_pred += self._q_bias
return y_pred

# Remove this function once Concrete Python fixes the multi-parameter bug with fully-leveled
# circuits and factorize it in the base class
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
def compile(self, *args, **kwargs) -> Circuit:
# If a configuration instance is given as a positional parameter, set the strategy to
# multi-parameter
if len(args) >= 2:
configuration = force_mono_parameter_in_configuration(args[1])
args_list = list(args)
args_list[1] = configuration
args = tuple(args_list)

# Else, retrieve the configuration in kwargs if it exists, or create a new one, and set the
# strategy to multi-parameter
else:
configuration = kwargs.get("configuration", None)
kwargs["configuration"] = force_mono_parameter_in_configuration(configuration)

return BaseEstimator.compile(self, *args, **kwargs)


class SklearnLinearRegressorMixin(SklearnLinearModelMixin, sklearn.base.RegressorMixin, ABC):
"""A Mixin class for sklearn linear regressors with FHE.
Expand Down
46 changes: 0 additions & 46 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import pandas
import pytest
import torch
from concrete.fhe import ParameterSelectionStrategy
from sklearn.decomposition import PCA
from sklearn.exceptions import ConvergenceWarning, UndefinedMetricWarning
from sklearn.metrics import make_scorer, matthews_corrcoef, top_k_accuracy_score
Expand Down Expand Up @@ -1056,19 +1055,6 @@ def check_exposition_structural_methods_decision_trees(model, x, y):
)


def check_mono_parameter_warnings(model, x, default_configuration):
"""Check that setting voluntarily a mono-parameter strategy properly raises a warning."""

# Set the parameter strategy to mono-parameter
default_configuration.parameter_selection_strategy = ParameterSelectionStrategy.MONO

with pytest.warns(
UserWarning,
match="Setting the parameter_selection_strategy to mono-parameter is not recommended.*",
):
model.compile(x, default_configuration)


@pytest.mark.parametrize("model_class, parameters", sklearn_models_and_datasets)
@pytest.mark.parametrize(
"n_bits",
Expand Down Expand Up @@ -1668,35 +1654,3 @@ def test_exposition_structural_methods_decision_trees(
print("Run check_exposition_structural_methods_decision_trees")

check_exposition_structural_methods_decision_trees(model, x, y)


@pytest.mark.parametrize("model_class, parameters", sklearn_models_and_datasets)
def test_mono_parameter_warnings(
model_class,
parameters,
load_data,
is_weekly_option,
default_configuration,
verbose=True,
):
"""Test that setting voluntarily a mono-parameter strategy properly raises a warning."""

# Remove this once Concrete Python fixes the multi-parameter bug with fully-leveled circuits
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
# Linear models are manually forced to use mono-parameter
if is_model_class_in_a_list(model_class, get_sklearn_linear_models()):
return

# KNN is manually forced to use mono-parameter
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978
if is_model_class_in_a_list(model_class, get_sklearn_neighbors_models()):
return

n_bits = min(N_BITS_REGULAR_BUILDS)

model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option)

if verbose:
print("Run check_mono_parameter_warnings")

check_mono_parameter_warnings(model, x, default_configuration)
3 changes: 1 addition & 2 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest
import torch
from concrete.fhe import Configuration, ParameterSelectionStrategy
from concrete.fhe import Configuration
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from concrete.ml.torch.hybrid_model import HybridFHEModel
Expand All @@ -20,7 +20,6 @@ def run_hybrid_model_test(
# Multi-parameter strategy is used in order to speed-up the FHE executions
configuration = Configuration(
single_precision=False,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)

# Create a hybrid model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def wrapper(*args, **kwargs):
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location=KEYGEN_CACHE_DIR,
parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
)

print("Compiling the model.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import torch
from concrete.fhe import Configuration, ParameterSelectionStrategy
from concrete.fhe import Configuration
from models import cnv_2w2a
from torch.utils.data import DataLoader
from tqdm import tqdm
Expand Down Expand Up @@ -106,7 +106,6 @@ def main(args):
cfg = Configuration(
verbose=True,
show_optimizer=args.show_optimizer,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)

for rounding_threshold_bits in rounding_threshold_bits_list:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torchvision
import torchvision.transforms as transforms
from concrete.fhe import Circuit, Configuration, ParameterSelectionStrategy
from concrete.fhe import Circuit, Configuration
from model import CNV

from concrete.ml.deployment.fhe_client_server import FHEModelDev
Expand Down Expand Up @@ -54,10 +54,7 @@ def main():
train_features_sub_set = model.clear_module(train_sub_set)

# Multi-parameter strategy is used in order to speed-up the FHE executions
configuration = Configuration(
show_optimizer=True,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)
configuration = Configuration(show_optimizer=True)

compilation_onnx_path = "compilation_model.onnx"
print("Compiling the model ...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
import torchvision
from brevitas import config
from concrete.fhe import Configuration, ParameterSelectionStrategy
from model import CNV
from scipy.special import softmax
from torch.backends import cudnn
Expand Down Expand Up @@ -110,14 +109,6 @@ def main():
with torch.no_grad():
train_features_sub_set = net.clear_module(train_sub_set)

optional_kwargs = {}

# Multi-parameter strategy is used in order to speed-up the FHE executions
optional_kwargs["configuration"] = Configuration(
dump_artifacts_on_unexpected_failures=True,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)

compilation_onnx_path = "compilation_model.onnx"
print("Compiling the model")
start_compile = time.time()
Expand All @@ -126,7 +117,6 @@ def main():
quantized_numpy_module = compile_brevitas_qat_model(
torch_model=net.encrypted_module,
torch_inputset=train_features_sub_set,
**optional_kwargs,
output_onnx_file=compilation_onnx_path,
)

Expand Down
3 changes: 1 addition & 2 deletions use_case_examples/mnist/MnistInFHE.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"import torch\n",
"\n",
"# Concrete-Python\n",
"from concrete.fhe import Configuration, ParameterSelectionStrategy\n",
"from concrete.fhe import Configuration\n",
"\n",
"# The QAT model\n",
"from model import MNISTQATModel # pylint: disable=no-name-in-module\n",
Expand Down Expand Up @@ -168,7 +168,6 @@
" enable_unsafe_features=True,\n",
" use_insecure_key_cache=True,\n",
" insecure_key_cache_location=\"/tmp/keycache\",\n",
" parameter_selection_strategy=ParameterSelectionStrategy.MULTI,\n",
" )\n",
"\n",
" if use_simulation:\n",
Expand Down