Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knn classifier in cml #217

Merged
merged 51 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
0fc4ad8
chore: update base.py with concrete ml v
kcelia Jul 20, 2023
8dc0199
chore: v2
kcelia Jul 20, 2023
771648f
chore: keep one class
kcelia Jul 21, 2023
4fc02ea
chore: remove other classes
kcelia Jul 21, 2023
481950d
chore: update
kcelia Jul 25, 2023
b1ffecc
chore: version 1 w
kcelia Jul 26, 2023
af2550a
chore: previous version
kcelia Jul 26, 2023
96cd821
chore: start testing
kcelia Aug 27, 2023
98de388
chore: first testing version
kcelia Aug 28, 2023
795842e
chore: add `_NEIGHBORS_MODELS` and `get_sklearn_neighbors_models` to …
kcelia Aug 28, 2023
bab9ee7
chore: add a new inheritance layer for classification
kcelia Aug 28, 2023
cf76270
chore: update serialize testing
kcelia Sep 1, 2023
ab6f93d
chore: fix serialization test
kcelia Sep 1, 2023
9898062
chore: fix gridsearch test + make conformance
kcelia Sep 1, 2023
a13d028
chore: update conformance
kcelia Sep 1, 2023
ead5c45
chore: correct pairwise euclidean_distances
kcelia Sep 1, 2023
9336b84
chore: remove other classes
kcelia Sep 4, 2023
a574717
chore: fix make pcc
kcelia Sep 4, 2023
833d468
chore: update test/common
kcelia Sep 4, 2023
7493592
chore: fix parameter search tests
kcelia Sep 4, 2023
e06bd40
chore: fix deployment tests
kcelia Sep 4, 2023
7005354
chore: reduce dataset size for knn
kcelia Sep 4, 2023
538f03a
chore: add self._y
kcelia Sep 5, 2023
a7bab6d
chore: fix test_p_error_global_p_error_simulation test
kcelia Sep 5, 2023
7bcaddc
chore: fix test_quantization
kcelia Sep 5, 2023
20edf01
chore: add encrypted argsort
kcelia Sep 5, 2023
e9f2c21
chore: decrease even more the knn dataset size
kcelia Sep 5, 2023
600f72c
chore: correct argsort and topk_indice naming
kcelia Sep 5, 2023
654983d
chore: remove topk_indice
kcelia Sep 5, 2023
e014ff0
chore: simplify multiplication
kcelia Sep 6, 2023
ef24859
chore: fix inference
kcelia Sep 7, 2023
108f922
chore: remove dequantization for sortargmax
kcelia Sep 7, 2023
4978506
chore: reduce even more the dataset size of knn
kcelia Sep 7, 2023
fb633d3
chore: decrease the default n_bit of knn class to 4.
kcelia Sep 7, 2023
2cc6b26
chore: fix test_dump_onn
kcelia Sep 7, 2023
fd5ff58
chore: fix double_fit test for KNN
kcelia Sep 8, 2023
3faad7a
chore: fix tests/common and tests/deployment
kcelia Sep 11, 2023
c1ef09b
chore: fix parameter_search test
kcelia Sep 11, 2023
b6ec7fc
chore: fix test_mono_param_waraning
kcelia Sep 11, 2023
f158db7
chore: fix grid_search test
kcelia Sep 11, 2023
4ea78fa
chore: fix predict_correctness
kcelia Sep 11, 2023
41c7cc5
chore: fix check_fitted_compiled_error_raises
kcelia Sep 12, 2023
1797dcf
chore: update
kcelia Sep 12, 2023
e94026c
chore: fix bug in prediction + fix p_error_simulation test
kcelia Sep 13, 2023
aeb9196
chore: resume show_mlir
kcelia Sep 13, 2023
0794adc
chore: reduce knn dataset
kcelia Sep 18, 2023
ca03c3c
chore: update
kcelia Sep 18, 2023
9d0a4dd
chore: force the configuration of KNN to run under MONO settings
kcelia Sep 18, 2023
a59aa96
chore: predict returns the topk labels
kcelia Sep 20, 2023
d5b6e46
chore: update check_for_divergent_predictions test for KNN
kcelia Sep 20, 2023
fd2c1c7
chore: add post_processing
kcelia Sep 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from concrete.ml.sklearn.base import (
BaseTreeEstimatorMixin,
QuantizedTorchEstimatorMixin,
SklearnKNeighborsMixin,
SklearnLinearModelMixin,
)

Expand Down Expand Up @@ -482,7 +483,12 @@ def check_is_good_execution_for_cml_vs_circuit_impl(
else:
assert isinstance(
model,
(QuantizedTorchEstimatorMixin, BaseTreeEstimatorMixin, SklearnLinearModelMixin),
(
QuantizedTorchEstimatorMixin,
BaseTreeEstimatorMixin,
SklearnLinearModelMixin,
SklearnKNeighborsMixin,
),
)

if model._is_a_public_cml_model: # pylint: disable=protected-access
Expand All @@ -492,8 +498,14 @@ def check_is_good_execution_for_cml_vs_circuit_impl(
# tests), especially since these results are tested in other tests such as the
# `check_subfunctions_in_fhe`
if is_classifier_or_partial_classifier(model):
results_cnp_circuit = model.predict_proba(*inputs, fhe=fhe_mode)
results_model = model.predict_proba(*inputs, fhe="disable")
if isinstance(model, SklearnKNeighborsMixin):
kcelia marked this conversation as resolved.
Show resolved Hide resolved
# For KNN `predict_proba` is not supported for now
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962
results_cnp_circuit = model.predict(*inputs, fhe=fhe_mode)
results_model = model.predict(*inputs, fhe="disable")
else:
results_cnp_circuit = model.predict_proba(*inputs, fhe=fhe_mode)
results_model = model.predict_proba(*inputs, fhe="disable")

else:
results_cnp_circuit = model.predict(*inputs, fhe=fhe_mode)
Expand Down
3 changes: 3 additions & 0 deletions src/concrete/ml/common/serialization/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def _get_fully_qualified_name(object_class: Type) -> str:
"skorch.dataset.Dataset",
"skorch.dataset.ValidSplit",
"inspect._empty",
"sklearn.neighbors._classification.KNeighborsClassifier",
"sklearn.metrics._dist_metrics.EuclideanDistance",
"sklearn.neighbors._kd_tree.KDTree",
]
)

Expand Down
22 changes: 20 additions & 2 deletions src/concrete/ml/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DecisionTreeRegressor,
ElasticNet,
GammaRegressor,
KNeighborsClassifier,
Lasso,
LinearRegression,
LinearSVC,
Expand Down Expand Up @@ -66,6 +67,7 @@
]

_classifier_models = [
KNeighborsClassifier,
DecisionTreeClassifier,
RandomForestClassifier,
XGBClassifier,
Expand Down Expand Up @@ -95,9 +97,25 @@
id=get_model_name(model),
)
for model in _classifier_models
if get_model_name(model) != "KNeighborsClassifier"
for n_classes in [2, 4]
] + [
pytest.param(
kcelia marked this conversation as resolved.
Show resolved Hide resolved
model,
{
"n_samples": 6,
"n_features": 2,
"n_classes": n_classes,
"n_informative": 2,
"n_redundant": 0,
},
id=get_model_name(model),
)
for model in [KNeighborsClassifier]
for n_classes in [2]
]


# Get the data-sets. The data generation is seeded in load_data.
# Only LinearRegression supports multi targets
# GammaRegressor, PoissonRegressor and TweedieRegressor only handle positive target values
Expand Down Expand Up @@ -141,8 +159,8 @@ def get_random_extract_of_sklearn_models_and_datasets():
unique_model_classes.append(m)

# Sanity checks to catch mistakes, such as returning an empty list
assert len(sklearn_models_and_datasets) == 28
assert len(unique_model_classes) == 18
assert len(sklearn_models_and_datasets) == 29
assert len(unique_model_classes) == 19

return unique_model_classes

Expand Down
11 changes: 9 additions & 2 deletions src/concrete/ml/search_parameters/p_error_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@
from tqdm import tqdm

from ..common.utils import is_brevitas_model, is_model_class_in_a_list
from ..sklearn import get_sklearn_neural_net_models, get_sklearn_tree_models
from ..sklearn import (
get_sklearn_neighbors_models,
get_sklearn_neural_net_models,
get_sklearn_tree_models,
)
from ..torch.compile import compile_brevitas_qat_model, compile_torch_model


Expand Down Expand Up @@ -126,7 +130,10 @@ def compile_and_simulated_fhe_inference(
dequantized_output = quantized_module.forward(calibration_data, fhe="simulate")

elif is_model_class_in_a_list(
estimator, get_sklearn_neural_net_models() + get_sklearn_tree_models()
estimator,
get_sklearn_neural_net_models()
+ get_sklearn_tree_models()
+ get_sklearn_neighbors_models(),
):
if not estimator.is_fitted:
estimator.fit(calibration_data, ground_truth)
Expand Down
28 changes: 27 additions & 1 deletion src/concrete/ml/sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@

from ..common.debugging.custom_assert import assert_true
from ..common.utils import is_classifier_or_partial_classifier, is_regressor_or_partial_regressor
from .base import _ALL_SKLEARN_MODELS, _LINEAR_MODELS, _NEURALNET_MODELS, _TREE_MODELS
from .base import (
_ALL_SKLEARN_MODELS,
_LINEAR_MODELS,
_NEIGHBORS_MODELS,
_NEURALNET_MODELS,
_TREE_MODELS,
)
from .glm import GammaRegressor, PoissonRegressor, TweedieRegressor
from .linear_model import ElasticNet, Lasso, LinearRegression, LogisticRegression, Ridge
from .neighbors import KNeighborsClassifier
from .qnn import NeuralNetClassifier, NeuralNetRegressor
from .rf import RandomForestClassifier, RandomForestRegressor
from .svm import LinearSVC, LinearSVR
Expand All @@ -31,6 +38,7 @@ def get_sklearn_models():
"linear": sorted(list(_LINEAR_MODELS), key=lambda m: m.__name__),
"tree": sorted(list(_TREE_MODELS), key=lambda m: m.__name__),
"neural_net": sorted(list(_NEURALNET_MODELS), key=lambda m: m.__name__),
"neighbors": sorted(list(_NEIGHBORS_MODELS), key=lambda m: m.__name__),
}
return ans

Expand Down Expand Up @@ -123,3 +131,21 @@ def get_sklearn_neural_net_models(
"""
prelist = get_sklearn_models()["neural_net"]
return _filter_models(prelist, classifier, regressor, str_in_class_name)


def get_sklearn_neighbors_models(
    classifier: bool = True, regressor: bool = True, str_in_class_name: List[str] = None
):
    """Return the neighbor-based models available in Concrete ML.

    The candidates are the models registered under the "neighbors" key of
    `get_sklearn_models()`; the selection itself is delegated to `_filter_models`.

    Args:
        classifier (bool): whether classifiers should be part of the result
        regressor (bool): whether regressors should be part of the result
        str_in_class_name (List[str]): if not None, only keep the models whose class name
            contains the given string or list of strings

    Returns:
        the list of neighbor models in Concrete ML that match the given filters
    """
    return _filter_models(
        get_sklearn_models()["neighbors"],
        classifier,
        regressor,
        str_in_class_name,
    )
Loading
Loading