Skip to content

Commit

Permalink
chore: remove manual set of multi and mono strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Sep 22, 2023
1 parent afd049a commit 9de1b2d
Show file tree
Hide file tree
Showing 11 changed files with 7 additions and 158 deletions.
5 changes: 0 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest
import torch
from concrete.fhe import Graph as CPGraph
from concrete.fhe import ParameterSelectionStrategy
from concrete.fhe.compilation import Circuit, Configuration
from concrete.fhe.mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from sklearn.datasets import make_classification, make_regression
Expand Down Expand Up @@ -147,9 +146,6 @@ def pytest_sessionfinish(session: pytest.Session, exitstatus): # pylint: disabl
def default_configuration():
"""Return the default test compilation configuration."""

# Remove parameter_selection_strategy once it is set to multi-parameter in Concrete Python
# by default
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3860
# Parameter `enable_unsafe_features` and `use_insecure_key_cache` are needed in order to be
# able to cache generated keys through `insecure_key_cache_location`. As the name suggests,
# these parameters are unsafe and should only be used for debugging in development
Expand All @@ -158,7 +154,6 @@ def default_configuration():
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location="ConcreteNumpyKeyCache",
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)


Expand Down
39 changes: 2 additions & 37 deletions src/concrete/ml/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import enum
import string
import warnings
from functools import partial
from types import FunctionType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -587,43 +586,9 @@ def all_values_are_of_dtype(*values: Any, dtypes: Union[str, List[str]]) -> bool
return all(_is_of_dtype(value, supported_dtypes) for value in values)


def set_multi_parameter_in_configuration(configuration: Optional[Configuration], **kwargs):
"""Build a Configuration instance with multi-parameter strategy, unless one is already given.
If the given Configuration instance is not None and the parameter strategy is set to MONO, a
warning is raised in order to make sure the user did it on purpose.
Args:
configuration (Optional[Configuration]): The configuration to consider.
**kwargs: Additional parameters to use for instantiating a new Configuration instance, if
configuration is None.
Returns:
configuration (Configuration): A configuration with multi-parameter strategy.
"""
assert (
"parameter_selection_strategy" not in kwargs
), "Please do not provide a parameter_selection_strategy parameter as it will be set to MULTI."
if configuration is None:
configuration = Configuration(
parameter_selection_strategy=ParameterSelectionStrategy.MULTI, **kwargs
)

elif configuration.parameter_selection_strategy == ParameterSelectionStrategy.MONO:
warnings.warn(
"Setting the parameter_selection_strategy to mono-parameter is not recommended as it "
"can deteriorate performances. If you set it voluntarily, this message can be ignored. "
"Else, please set parameter_selection_strategy to ParameterSelectionStrategy.MULTI "
"when initializing the Configuration instance.",
stacklevel=2,
)

return configuration


# Remove this function once Concrete Python fixes the multi-parameter bug with fully-leveled
# Remove this function once Concrete Python fixes the multi-parameter bug with KNN
# circuits
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978
def force_mono_parameter_in_configuration(configuration: Optional[Configuration], **kwargs):
"""Force configuration to mono-parameter strategy.
Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
check_there_is_no_p_error_options_in_configuration,
generate_proxy_function,
manage_parameters_for_pbs_errors,
set_multi_parameter_in_configuration,
to_tuple,
)
from .base_quantized_op import ONNXOpInputOutputType, QuantizedOp
Expand Down Expand Up @@ -639,11 +638,6 @@ def compile(
# Find the right way to set parameters for compiler, depending on the way we want to default
p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

# Remove this function once the default strategy is set to multi-parameter in Concrete
# Python
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3860
configuration = set_multi_parameter_in_configuration(configuration)

# Jit compiler is now deprecated and will soon be removed, it is thus forced to False
# by default
self.fhe_circuit = compiler.compile(
Expand Down
42 changes: 0 additions & 42 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
force_mono_parameter_in_configuration,
generate_proxy_function,
manage_parameters_for_pbs_errors,
set_multi_parameter_in_configuration,
)
from ..onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT
from ..onnx.onnx_model_manipulations import clean_graph_after_node_op_type, remove_node_types
Expand Down Expand Up @@ -1363,27 +1362,6 @@ def _get_module_to_compile(self) -> Union[Compiler, QuantizedModule]:
return compiler

def compile(self, *args, **kwargs) -> Circuit:

# Factorize this in the base class once Concrete Python fixes the multi-parameter bug
# with fully-leveled circuits
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
# Remove this function once the default strategy is set to multi-parameter in Concrete
# Python
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3860
# If a configuration instance is given as a positional parameter, set the strategy to
# multi-parameter
if len(args) >= 2:
configuration = set_multi_parameter_in_configuration(args[1])
args_list = list(args)
args_list[1] = configuration
args = tuple(args_list)

# Else, retrieve the configuration in kwargs if it exists, or create a new one, and set the
# strategy to multi-parameter
else:
configuration = kwargs.get("configuration", None)
kwargs["configuration"] = set_multi_parameter_in_configuration(configuration)

BaseEstimator.compile(self, *args, **kwargs)

# Check that the graph only has a single output
Expand Down Expand Up @@ -1638,26 +1616,6 @@ def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray:
y_pred += self._q_bias
return y_pred

# Remove this function once Concrete Python fixes the multi-parameter bug with fully-leveled
# circuits and factorize it in the base class
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
def compile(self, *args, **kwargs) -> Circuit:
# If a configuration instance is given as a positional parameter, set the strategy to
# multi-parameter
if len(args) >= 2:
configuration = force_mono_parameter_in_configuration(args[1])
args_list = list(args)
args_list[1] = configuration
args = tuple(args_list)

# Else, retrieve the configuration in kwargs if it exists, or create a new one, and set the
# strategy to multi-parameter
else:
configuration = kwargs.get("configuration", None)
kwargs["configuration"] = force_mono_parameter_in_configuration(configuration)

return BaseEstimator.compile(self, *args, **kwargs)


class SklearnLinearRegressorMixin(SklearnLinearModelMixin, sklearn.base.RegressorMixin, ABC):
"""A Mixin class for sklearn linear regressors with FHE.
Expand Down
46 changes: 0 additions & 46 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import pandas
import pytest
import torch
from concrete.fhe import ParameterSelectionStrategy
from sklearn.decomposition import PCA
from sklearn.exceptions import ConvergenceWarning, UndefinedMetricWarning
from sklearn.metrics import make_scorer, matthews_corrcoef, top_k_accuracy_score
Expand Down Expand Up @@ -1056,19 +1055,6 @@ def check_exposition_structural_methods_decision_trees(model, x, y):
)


def check_mono_parameter_warnings(model, x, default_configuration):
"""Check that setting voluntarily a mono-parameter strategy properly raises a warning."""

# Set the parameter strategy to mono-parameter
default_configuration.parameter_selection_strategy = ParameterSelectionStrategy.MONO

with pytest.warns(
UserWarning,
match="Setting the parameter_selection_strategy to mono-parameter is not recommended.*",
):
model.compile(x, default_configuration)


@pytest.mark.parametrize("model_class, parameters", sklearn_models_and_datasets)
@pytest.mark.parametrize(
"n_bits",
Expand Down Expand Up @@ -1668,35 +1654,3 @@ def test_exposition_structural_methods_decision_trees(
print("Run check_exposition_structural_methods_decision_trees")

check_exposition_structural_methods_decision_trees(model, x, y)


@pytest.mark.parametrize("model_class, parameters", sklearn_models_and_datasets)
def test_mono_parameter_warnings(
model_class,
parameters,
load_data,
is_weekly_option,
default_configuration,
verbose=True,
):
"""Test that setting voluntarily a mono-parameter strategy properly raises a warning."""

# Remove this once Concrete Python fixes the multi-parameter bug with fully-leveled circuits
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3862
# Linear models are manually forced to use mono-parameter
if is_model_class_in_a_list(model_class, get_sklearn_linear_models()):
return

# KNN is manually forced to use mono-parameter
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978
if is_model_class_in_a_list(model_class, get_sklearn_neighbors_models()):
return

n_bits = min(N_BITS_REGULAR_BUILDS)

model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option)

if verbose:
print("Run check_mono_parameter_warnings")

check_mono_parameter_warnings(model, x, default_configuration)
3 changes: 1 addition & 2 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest
import torch
from concrete.fhe import Configuration, ParameterSelectionStrategy
from concrete.fhe import Configuration
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from concrete.ml.torch.hybrid_model import HybridFHEModel
Expand All @@ -20,7 +20,6 @@ def run_hybrid_model_test(
# Multi-parameter strategy is used in order to speed-up the FHE executions
configuration = Configuration(
single_precision=False,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)

# Create a hybrid model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def wrapper(*args, **kwargs):
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location=KEYGEN_CACHE_DIR,
parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
)

print("Compiling the model.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import torch
from concrete.fhe import Configuration, ParameterSelectionStrategy
from concrete.fhe import Configuration
from models import cnv_2w2a
from torch.utils.data import DataLoader
from tqdm import tqdm
Expand Down Expand Up @@ -106,7 +106,6 @@ def main(args):
cfg = Configuration(
verbose=True,
show_optimizer=args.show_optimizer,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)

for rounding_threshold_bits in rounding_threshold_bits_list:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torchvision
import torchvision.transforms as transforms
from concrete.fhe import Circuit, Configuration, ParameterSelectionStrategy
from concrete.fhe import Circuit, Configuration
from model import CNV

from concrete.ml.deployment.fhe_client_server import FHEModelDev
Expand Down Expand Up @@ -54,10 +54,7 @@ def main():
train_features_sub_set = model.clear_module(train_sub_set)

# Multi-parameter strategy is used in order to speed-up the FHE executions
configuration = Configuration(
show_optimizer=True,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)
configuration = Configuration(show_optimizer=True)

compilation_onnx_path = "compilation_model.onnx"
print("Compiling the model ...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
import torchvision
from brevitas import config
from concrete.fhe import Configuration, ParameterSelectionStrategy
from model import CNV
from scipy.special import softmax
from torch.backends import cudnn
Expand Down Expand Up @@ -110,14 +109,6 @@ def main():
with torch.no_grad():
train_features_sub_set = net.clear_module(train_sub_set)

optional_kwargs = {}

# Multi-parameter strategy is used in order to speed-up the FHE executions
optional_kwargs["configuration"] = Configuration(
dump_artifacts_on_unexpected_failures=True,
parameter_selection_strategy=ParameterSelectionStrategy.MULTI,
)

compilation_onnx_path = "compilation_model.onnx"
print("Compiling the model")
start_compile = time.time()
Expand All @@ -126,7 +117,6 @@ def main():
quantized_numpy_module = compile_brevitas_qat_model(
torch_model=net.encrypted_module,
torch_inputset=train_features_sub_set,
**optional_kwargs,
output_onnx_file=compilation_onnx_path,
)

Expand Down
3 changes: 1 addition & 2 deletions use_case_examples/mnist/MnistInFHE.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"import torch\n",
"\n",
"# Concrete-Python\n",
"from concrete.fhe import Configuration, ParameterSelectionStrategy\n",
"from concrete.fhe import Configuration\n",
"\n",
"# The QAT model\n",
"from model import MNISTQATModel # pylint: disable=no-name-in-module\n",
Expand Down Expand Up @@ -168,7 +168,6 @@
" enable_unsafe_features=True,\n",
" use_insecure_key_cache=True,\n",
" insecure_key_cache_location=\"/tmp/keycache\",\n",
" parameter_selection_strategy=ParameterSelectionStrategy.MULTI,\n",
" )\n",
"\n",
" if use_simulation:\n",
Expand Down

0 comments on commit 9de1b2d

Please sign in to comment.