Skip to content

Commit

Permalink
feat: Add SGD regressor
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama authored Oct 23, 2023
1 parent d3b5060 commit abb143c
Show file tree
Hide file tree
Showing 15 changed files with 1,063 additions and 249 deletions.
2 changes: 2 additions & 0 deletions benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"NeuralNetRegressor",
"RandomForestRegressor",
"XGBRegressor",
"SGDRegressor",
]
for model_name in REGRESSORS_NAMES:
try:
Expand Down Expand Up @@ -341,6 +342,7 @@ def should_test_config_in_fhe(
"TweedieRegressor",
"PoissonRegressor",
"GammaRegressor",
"SGDRegressor",
}:
return True

Expand Down
1 change: 1 addition & 0 deletions docs/built-in-models/linear.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Concrete ML provides several of the most popular linear models for `regression`
| [Lasso](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-lasso) | [Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html#sklearn.linear_model.Lasso) |
| [Ridge](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-ridge) | [Ridge](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge) |
| [ElasticNet](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-elasticnet) | [ElasticNet](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html#sklearn.linear_model.ElasticNet) |
| [SGDRegressor](../developer-guide/api/concrete.ml.sklearn.linear_model.md#class-sgdregressor) | [SGDRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDRegressor.html) |

Using these models in FHE is extremely similar to what can be done with scikit-learn's [API](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model), making it easy for data scientists who have used this framework to get started with Concrete ML.

Expand Down
4 changes: 3 additions & 1 deletion docs/developer-guide/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
- [`concrete.ml.sklearn.base`](./concrete.ml.sklearn.base.md#module-concretemlsklearnbase): Base classes for all estimators.
- [`concrete.ml.sklearn.glm`](./concrete.ml.sklearn.glm.md#module-concretemlsklearnglm): Implement sklearn's Generalized Linear Models (GLM).
- [`concrete.ml.sklearn.linear_model`](./concrete.ml.sklearn.linear_model.md#module-concretemlsklearnlinear_model): Implement sklearn linear model.
- [`concrete.ml.sklearn.neighbors`](./concrete.ml.sklearn.neighbors.md#module-concretemlsklearnneighbors): Implement sklearn linear model.
- [`concrete.ml.sklearn.neighbors`](./concrete.ml.sklearn.neighbors.md#module-concretemlsklearnneighbors): Implement sklearn neighbors model.
- [`concrete.ml.sklearn.qnn`](./concrete.ml.sklearn.qnn.md#module-concretemlsklearnqnn): Scikit-learn interface for fully-connected quantized neural networks.
- [`concrete.ml.sklearn.qnn_module`](./concrete.ml.sklearn.qnn_module.md#module-concretemlsklearnqnn_module): Sparse Quantized Neural Network torch module.
- [`concrete.ml.sklearn.rf`](./concrete.ml.sklearn.rf.md#module-concretemlsklearnrf): Implement RandomForest models.
Expand Down Expand Up @@ -182,6 +182,7 @@
- [`base.SklearnLinearClassifierMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearclassifiermixin): A Mixin class for sklearn linear classifiers with FHE.
- [`base.SklearnLinearModelMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearmodelmixin): A Mixin class for sklearn linear models with FHE.
- [`base.SklearnLinearRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnlinearregressormixin): A Mixin class for sklearn linear regressors with FHE.
- [`base.SklearnSGDRegressorMixin`](./concrete.ml.sklearn.base.md#class-sklearnsgdregressormixin): A Mixin class for sklearn SGD regressors with FHE.
- [`glm.GammaRegressor`](./concrete.ml.sklearn.glm.md#class-gammaregressor): A Gamma regression model with FHE.
- [`glm.PoissonRegressor`](./concrete.ml.sklearn.glm.md#class-poissonregressor): A Poisson regression model with FHE.
- [`glm.TweedieRegressor`](./concrete.ml.sklearn.glm.md#class-tweedieregressor): A Tweedie regression model with FHE.
Expand All @@ -190,6 +191,7 @@
- [`linear_model.LinearRegression`](./concrete.ml.sklearn.linear_model.md#class-linearregression): A linear regression model with FHE.
- [`linear_model.LogisticRegression`](./concrete.ml.sklearn.linear_model.md#class-logisticregression): A logistic regression model with FHE.
- [`linear_model.Ridge`](./concrete.ml.sklearn.linear_model.md#class-ridge): A Ridge regression model with FHE.
- [`linear_model.SGDRegressor`](./concrete.ml.sklearn.linear_model.md#class-sgdregressor): An FHE linear regression model fitted with stochastic gradient descent.
- [`neighbors.KNeighborsClassifier`](./concrete.ml.sklearn.neighbors.md#class-kneighborsclassifier): A k-nearest neighbors classifier model with FHE.
- [`qnn.NeuralNetClassifier`](./concrete.ml.sklearn.qnn.md#class-neuralnetclassifier): A Fully-Connected Neural Network classifier with FHE.
- [`qnn.NeuralNetRegressor`](./concrete.ml.sklearn.qnn.md#class-neuralnetregressor): A Fully-Connected Neural Network regressor with FHE.
Expand Down
42 changes: 21 additions & 21 deletions docs/developer-guide/api/concrete.ml.common.utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Utils that can be re-used by other pieces of code in the module.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L94"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L95"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `replace_invalid_arg_name_chars`

Expand All @@ -39,7 +39,7 @@ This does not check that the starting character of arg_name is valid.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L113"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L114"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `generate_proxy_function`

Expand All @@ -65,7 +65,7 @@ This returns a runtime compiled function with the sanitized argument names passe

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L154"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L155"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_onnx_opset_version`

Expand All @@ -85,7 +85,7 @@ Return the ONNX opset_version.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L169"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L170"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `manage_parameters_for_pbs_errors`

Expand Down Expand Up @@ -122,7 +122,7 @@ Note that global_p_error is currently set to 0 in the FHE simulation mode.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L214"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L215"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `check_there_is_no_p_error_options_in_configuration`

Expand All @@ -140,7 +140,7 @@ It would be dangerous, since we set them in direct arguments in our calls to Con

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L235"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L236"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_model_class`

Expand All @@ -159,7 +159,7 @@ The model's class.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L257"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L258"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_model_class_in_a_list`

Expand All @@ -179,7 +179,7 @@ If the model's class is in the list or not.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L271"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L272"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_model_name`

Expand All @@ -198,7 +198,7 @@ the model's name.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L284"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L285"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_classifier_or_partial_classifier`

Expand All @@ -218,7 +218,7 @@ Indicate if the model class represents a classifier.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L296"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L297"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_regressor_or_partial_regressor`

Expand All @@ -238,7 +238,7 @@ Indicate if the model class represents a regressor.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L308"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L309"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_pandas_dataframe`

Expand All @@ -260,7 +260,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L324"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L325"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_pandas_series`

Expand All @@ -282,7 +282,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L340"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L341"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_pandas_type`

Expand All @@ -302,7 +302,7 @@ Indicate if the input container is a Pandas DataFrame or Series.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L435"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L436"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `check_dtype_and_cast`

Expand Down Expand Up @@ -334,7 +334,7 @@ If values types don't match with any supported type or the expected dtype, raise

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L487"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L488"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `compute_bits_precision`

Expand All @@ -354,7 +354,7 @@ Compute the number of bits required to represent x.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L499"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L500"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `is_brevitas_model`

Expand All @@ -374,7 +374,7 @@ Check if a model is a Brevitas type.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L517"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L518"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `to_tuple`

Expand All @@ -394,7 +394,7 @@ Make the input a tuple if it is not already the case.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L533"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L534"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `all_values_are_integers`

Expand All @@ -414,7 +414,7 @@ Indicate if all unpacked values are of a supported integer dtype.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L546"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L547"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `all_values_are_floats`

Expand All @@ -434,7 +434,7 @@ Indicate if all unpacked values are of a supported float dtype.

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L559"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L560"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `all_values_are_of_dtype`

Expand All @@ -455,7 +455,7 @@ Indicate if all unpacked values are of the specified dtype(s).

______________________________________________________________________

<a href="../../../src/concrete/ml/common/utils.py#L51"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/common/utils.py#L52"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>class</kbd> `FheMode`

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Get the post-processing parameters.

______________________________________________________________________

<a href="../../../src/concrete/ml/quantization/quantized_module.py#L703"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../../src/concrete/ml/quantization/quantized_module.py#L698"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

### <kbd>method</kbd> `bitwidth_and_range_report`

Expand Down
Loading

0 comments on commit abb143c

Please sign in to comment.