diff --git a/docs/references/api/README.md b/docs/references/api/README.md
index 45b66d6fb..0922eac71 100644
--- a/docs/references/api/README.md
+++ b/docs/references/api/README.md
@@ -60,6 +60,7 @@
- [`decoder.ConcreteDecoder`](./concrete.ml.common.serialization.decoder.md#class-concretedecoder): Custom json decoder to handle non-native types found in serialized Concrete ML objects.
- [`encoder.ConcreteEncoder`](./concrete.ml.common.serialization.encoder.md#class-concreteencoder): Custom json encoder to handle non-native types found in serialized Concrete ML objects.
- [`utils.FheMode`](./concrete.ml.common.utils.md#class-fhemode): Enum representing the execution mode.
+- [`fhe_client_server.DeploymentMode`](./concrete.ml.deployment.fhe_client_server.md#class-deploymentmode): Mode for the FHE API.
- [`fhe_client_server.FHEModelClient`](./concrete.ml.deployment.fhe_client_server.md#class-fhemodelclient): Client API to encrypt and decrypt FHE data.
- [`fhe_client_server.FHEModelDev`](./concrete.ml.deployment.fhe_client_server.md#class-fhemodeldev): Dev API to save the model and then load and run the FHE circuit.
- [`fhe_client_server.FHEModelServer`](./concrete.ml.deployment.fhe_client_server.md#class-fhemodelserver): Server API to load and run the FHE circuit.
@@ -84,6 +85,8 @@
- [`torch_models.FCSeq`](./concrete.ml.pytest.torch_models.md#class-fcseq): Torch model that should generate MatMul->Add ONNX patterns.
- [`torch_models.FCSeqAddBiasVec`](./concrete.ml.pytest.torch_models.md#class-fcseqaddbiasvec): Torch model that should generate MatMul->Add ONNX patterns.
- [`torch_models.FCSmall`](./concrete.ml.pytest.torch_models.md#class-fcsmall): Torch model for the tests.
+- [`torch_models.IdentityExpandModel`](./concrete.ml.pytest.torch_models.md#class-identityexpandmodel): Model that only adds an empty dimension at axis 0.
+- [`torch_models.IdentityExpandMultiOutputModel`](./concrete.ml.pytest.torch_models.md#class-identityexpandmultioutputmodel): Model that only adds an empty dimension at axis 0, and returns the initial input as well.
- [`torch_models.ManualLogisticRegressionTraining`](./concrete.ml.pytest.torch_models.md#class-manuallogisticregressiontraining): PyTorch module for performing SGD training.
- [`torch_models.MultiInputNN`](./concrete.ml.pytest.torch_models.md#class-multiinputnn): Torch model to test multiple inputs forward.
- [`torch_models.MultiInputNNConfigurable`](./concrete.ml.pytest.torch_models.md#class-multiinputnnconfigurable): Torch model to test multiple inputs forward.
@@ -372,6 +375,7 @@
- [`tree_to_numpy.add_transpose_after_last_node`](./concrete.ml.sklearn.tree_to_numpy.md#function-add_transpose_after_last_node): Add transpose after last node.
- [`tree_to_numpy.assert_add_node_and_constant_in_xgboost_regressor_graph`](./concrete.ml.sklearn.tree_to_numpy.md#function-assert_add_node_and_constant_in_xgboost_regressor_graph): Assert if an Add node with a specific constant exists in the ONNX graph.
- [`tree_to_numpy.get_onnx_model`](./concrete.ml.sklearn.tree_to_numpy.md#function-get_onnx_model): Create ONNX model with Hummingbird convert method.
+- [`tree_to_numpy.onnx_fp32_model_to_quantized_model`](./concrete.ml.sklearn.tree_to_numpy.md#function-onnx_fp32_model_to_quantized_model): Build an FHE-compliant ONNX model using a fitted scikit-learn model.
- [`tree_to_numpy.preprocess_tree_predictions`](./concrete.ml.sklearn.tree_to_numpy.md#function-preprocess_tree_predictions): Apply post-processing from the graph.
- [`tree_to_numpy.tree_onnx_graph_preprocessing`](./concrete.ml.sklearn.tree_to_numpy.md#function-tree_onnx_graph_preprocessing): Apply pre-processing onto the ONNX graph.
- [`tree_to_numpy.tree_to_numpy`](./concrete.ml.sklearn.tree_to_numpy.md#function-tree_to_numpy): Convert the tree inference to a numpy functions using Hummingbird.
diff --git a/docs/references/api/concrete.ml.common.utils.md b/docs/references/api/concrete.ml.common.utils.md
index e3ab5e60a..e60dde0f1 100644
--- a/docs/references/api/concrete.ml.common.utils.md
+++ b/docs/references/api/concrete.ml.common.utils.md
@@ -17,7 +17,7 @@ Utils that can be re-used by other pieces of code in the module.
______________________________________________________________________
-
+
## function `replace_invalid_arg_name_chars`
@@ -39,7 +39,7 @@ This does not check that the starting character of arg_name is valid.
______________________________________________________________________
-
+
## function `generate_proxy_function`
@@ -65,7 +65,7 @@ This returns a runtime compiled function with the sanitized argument names passe
______________________________________________________________________
-
+
## function `get_onnx_opset_version`
@@ -85,7 +85,7 @@ Return the ONNX opset_version.
______________________________________________________________________
-
+
## function `manage_parameters_for_pbs_errors`
@@ -122,7 +122,7 @@ Note that global_p_error is currently set to 0 in the FHE simulation mode.
______________________________________________________________________
-
+
## function `check_there_is_no_p_error_options_in_configuration`
@@ -140,7 +140,7 @@ It would be dangerous, since we set them in direct arguments in our calls to Con
______________________________________________________________________
-
+
## function `get_model_class`
@@ -159,7 +159,7 @@ The model's class.
______________________________________________________________________
-
+
## function `is_model_class_in_a_list`
@@ -179,7 +179,7 @@ If the model's class is in the list or not.
______________________________________________________________________
-
+
## function `get_model_name`
@@ -198,7 +198,7 @@ the model's name.
______________________________________________________________________
-
+
## function `is_classifier_or_partial_classifier`
@@ -218,7 +218,7 @@ Indicate if the model class represents a classifier.
______________________________________________________________________
-
+
## function `is_regressor_or_partial_regressor`
@@ -238,7 +238,7 @@ Indicate if the model class represents a regressor.
______________________________________________________________________
-
+
## function `is_pandas_dataframe`
@@ -260,7 +260,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t
______________________________________________________________________
-
+
## function `is_pandas_series`
@@ -282,7 +282,7 @@ This function is inspired from Scikit-Learn's test validation tools and avoids t
______________________________________________________________________
-
+
## function `is_pandas_type`
@@ -302,7 +302,7 @@ Indicate if the input container is a Pandas DataFrame or Series.
______________________________________________________________________
-
+
## function `check_dtype_and_cast`
@@ -334,7 +334,7 @@ If values types don't match with any supported type or the expected dtype, raise
______________________________________________________________________
-
+
## function `compute_bits_precision`
@@ -354,7 +354,7 @@ Compute the number of bits required to represent x.
______________________________________________________________________
-
+
## function `is_brevitas_model`
@@ -374,7 +374,7 @@ Check if a model is a Brevitas type.
______________________________________________________________________
-
+
## function `to_tuple`
@@ -394,7 +394,7 @@ Make the input a tuple if it is not already the case.
______________________________________________________________________
-
+
## function `all_values_are_integers`
@@ -414,7 +414,7 @@ Indicate if all unpacked values are of a supported integer dtype.
______________________________________________________________________
-
+
## function `all_values_are_floats`
@@ -434,12 +434,16 @@ Indicate if all unpacked values are of a supported float dtype.
______________________________________________________________________
-
+
## function `all_values_are_of_dtype`
```python
-all_values_are_of_dtype(*values: Any, dtypes: Union[str, List[str]]) → bool
+all_values_are_of_dtype(
+ *values: Any,
+ dtypes: Union[str, List[str]],
+ allow_none: bool = False
+) → bool
```
Indicate if all unpacked values are of the specified dtype(s).
@@ -448,6 +452,7 @@ Indicate if all unpacked values are of the specified dtype(s).
- `*values (Any)`: The values to consider.
- `dtypes` (Union\[str, List\[str\]\]): The dtype(s) to consider.
+- `allow_none` (bool): Indicate if the values can be None.
**Returns:**
@@ -455,7 +460,7 @@ Indicate if all unpacked values are of the specified dtype(s).
______________________________________________________________________
-
+
## function `array_allclose_and_same_shape`
@@ -485,7 +490,7 @@ Check if two numpy arrays are equal within a tolerances and have the same shape.
______________________________________________________________________
-
+
## function `process_rounding_threshold_bits`
@@ -511,7 +516,7 @@ Check and process the rounding_threshold_bits parameter.
______________________________________________________________________
-
+
## class `FheMode`
diff --git a/docs/references/api/concrete.ml.deployment.fhe_client_server.md b/docs/references/api/concrete.ml.deployment.fhe_client_server.md
index 812acc191..569c520da 100644
--- a/docs/references/api/concrete.ml.deployment.fhe_client_server.md
+++ b/docs/references/api/concrete.ml.deployment.fhe_client_server.md
@@ -12,7 +12,7 @@ APIs for FHE deployment.
______________________________________________________________________
-
+
## function `check_concrete_versions`
@@ -34,13 +34,21 @@ This function loads the version JSON file found in client.zip or server.zip file
______________________________________________________________________
-
+
+
+## class `DeploymentMode`
+
+Mode for the FHE API.
+
+______________________________________________________________________
+
+
## class `FHEModelServer`
Server API to load and run the FHE circuit.
-
+
### method `__init__`
@@ -56,7 +64,7 @@ Initialize the FHE API.
______________________________________________________________________
-
+
### method `load`
@@ -68,37 +76,37 @@ Load the circuit.
______________________________________________________________________
-
+
### method `run`
```python
run(
- serialized_encrypted_quantized_data: bytes,
+ serialized_encrypted_quantized_data: Union[bytes, Value, Tuple[bytes, ...], Tuple[Value, ...]],
serialized_evaluation_keys: bytes
-) → bytes
+) → Union[bytes, Value, Tuple[bytes, ...], Tuple[Value, ...]]
```
Run the model on the server over encrypted data.
**Args:**
-- `serialized_encrypted_quantized_data` (bytes): the encrypted, quantized and serialized data
-- `serialized_evaluation_keys` (bytes): the serialized evaluation keys
+- `serialized_encrypted_quantized_data` (Union\[bytes, fhe.Value, Tuple\[bytes, ...\], Tuple\[fhe.Value, ...\]\]): The encrypted and quantized values to consider. If these values are serialized (in bytes), they are first deserialized.
+- `serialized_evaluation_keys` (bytes): The evaluation keys. If they are serialized (in bytes), they are first deserialized.
**Returns:**
-- `bytes`: the result of the model
+- `Union[bytes, fhe.Value, Tuple[bytes, ...], Tuple[fhe.Value, ...]]`: The model's encrypted and quantized results. If the inputs were initially serialized, the outputs are also serialized.
______________________________________________________________________
-
+
## class `FHEModelDev`
Dev API to save the model and then load and run the FHE circuit.
-
+
### method `__init__`
@@ -115,33 +123,38 @@ Initialize the FHE API.
______________________________________________________________________
-
+
### method `save`
```python
-save(via_mlir: bool = False)
+save(
+ mode: DeploymentMode = <DeploymentMode.INFERENCE: 'inference'>,
+ via_mlir: bool = False
+)
```
Export all needed artifacts for the client and server.
**Arguments:**
-- `via_mlir` (bool): serialize with `via_mlir` option from Concrete-Python. For more details on the topic please refer to Concrete-Python's documentation.
+- `mode` (DeploymentMode): The mode in which to save the FHE circuit, either "inference" or "training".
+- `via_mlir` (bool): Serialize with the `via_mlir` option from Concrete-Python.
**Raises:**
-- `Exception`: path_dir is not empty
+- `Exception`: If path_dir is not empty or if the training module does not exist.
+- `ValueError`: If mode is not "inference" or "training".
______________________________________________________________________
-
+
## class `FHEModelClient`
Client API to encrypt and decrypt FHE data.
-
+
### method `__init__`
@@ -158,49 +171,51 @@ Initialize the FHE API.
______________________________________________________________________
-
+
### method `deserialize_decrypt`
```python
-deserialize_decrypt(serialized_encrypted_quantized_result: bytes) → ndarray
+deserialize_decrypt(
+ *serialized_encrypted_quantized_result: Optional[bytes]
+) → Union[Any, Tuple[Any, ...]]
```
Deserialize and decrypt the values.
**Args:**
-- `serialized_encrypted_quantized_result` (bytes): the serialized, encrypted and quantized result
+- `serialized_encrypted_quantized_result` (Optional\[bytes\]): The serialized, encrypted and quantized values.
**Returns:**
-- `numpy.ndarray`: the decrypted and deserialized values
+- `Union[Any, Tuple[Any, ...]]`: The decrypted and deserialized values.
______________________________________________________________________
-
+
### method `deserialize_decrypt_dequantize`
```python
deserialize_decrypt_dequantize(
- serialized_encrypted_quantized_result: bytes
-) → ndarray
+ *serialized_encrypted_quantized_result: Optional[bytes]
+) → Union[ndarray, Tuple[ndarray, ...]]
```
Deserialize, decrypt and de-quantize the values.
**Args:**
-- `serialized_encrypted_quantized_result` (bytes): the serialized, encrypted and quantized result
+- `serialized_encrypted_quantized_result` (Optional\[bytes\]): The serialized, encrypted and quantized values.
**Returns:**
-- `numpy.ndarray`: the decrypted (de-quantized) values
+- `Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]`: The clear float values.
______________________________________________________________________
-
+
### method `generate_private_and_evaluation_keys`
@@ -216,7 +231,7 @@ Generate the private and evaluation keys.
______________________________________________________________________
-
+
### method `get_serialized_evaluation_keys`
@@ -232,7 +247,7 @@ Get the serialized evaluation keys.
______________________________________________________________________
-
+
### method `load`
@@ -244,20 +259,22 @@ Load the quantizers along with the FHE specs.
______________________________________________________________________
-
+
### method `quantize_encrypt_serialize`
```python
-quantize_encrypt_serialize(x: ndarray) → bytes
+quantize_encrypt_serialize(
+ *x: Optional[ndarray]
+) → Union[bytes, NoneType, Tuple[Union[bytes, NoneType], ...]]
```
Quantize, encrypt and serialize the values.
**Args:**
-- `x` (numpy.ndarray): the values to quantize, encrypt and serialize
+- `x` (Optional\[numpy.ndarray\]): The values to quantize, encrypt and serialize.
**Returns:**
-- `bytes`: the quantized, encrypted and serialized values
+- `Union[bytes, Tuple[bytes, ...]]`: The quantized, encrypted and serialized values.
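Taken together, the updated `FHEModelDev`, `FHEModelClient` and `FHEModelServer` APIs support the usual save/encrypt/run/decrypt round trip. The following is a minimal sketch of that flow, assuming `model` is a fitted and compiled Concrete ML model and `X` is the corresponding input array; the `deployment` and `keys` paths are placeholders chosen for illustration.

```python
from concrete.ml.deployment.fhe_client_server import (
    FHEModelClient,
    FHEModelDev,
    FHEModelServer,
)

# Developer side: export the client and server artifacts
# (mode defaults to inference; pass mode="training" to export a training circuit)
dev = FHEModelDev(path_dir="deployment", model=model)
dev.save()

# Client side: generate keys, then quantize, encrypt and serialize the input
client = FHEModelClient(path_dir="deployment", key_dir="keys")
client.generate_private_and_evaluation_keys()
evaluation_keys = client.get_serialized_evaluation_keys()
encrypted_input = client.quantize_encrypt_serialize(X[[0]])

# Server side: run the FHE circuit on the serialized encrypted input
server = FHEModelServer(path_dir="deployment")
server.load()
encrypted_result = server.run(encrypted_input, evaluation_keys)

# Client side: decrypt and de-quantize the (possibly multi-output) result
result = client.deserialize_decrypt_dequantize(encrypted_result)
```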
diff --git a/docs/references/api/concrete.ml.onnx.convert.md b/docs/references/api/concrete.ml.onnx.convert.md
index 8172c8020..11e3b4788 100644
--- a/docs/references/api/concrete.ml.onnx.convert.md
+++ b/docs/references/api/concrete.ml.onnx.convert.md
@@ -59,7 +59,7 @@ Get the numpy equivalent forward of the provided torch Module.
______________________________________________________________________
-
+
## function `preprocess_onnx_model`
@@ -84,7 +84,7 @@ Get the numpy equivalent forward of the provided ONNX model.
______________________________________________________________________
-
+
## function `get_equivalent_numpy_forward_from_onnx`
@@ -108,7 +108,7 @@ Get the numpy equivalent forward of the provided ONNX model.
______________________________________________________________________
-
+
## function `get_equivalent_numpy_forward_from_onnx_tree`
diff --git a/docs/references/api/concrete.ml.onnx.onnx_impl_utils.md b/docs/references/api/concrete.ml.onnx.onnx_impl_utils.md
index 77d914870..f1dbf1c1d 100644
--- a/docs/references/api/concrete.ml.onnx.onnx_impl_utils.md
+++ b/docs/references/api/concrete.ml.onnx.onnx_impl_utils.md
@@ -132,7 +132,7 @@ This constant can be a tensor of the same shape as the input or a scalar.
______________________________________________________________________
-
+
## function `rounded_comparison`
diff --git a/docs/references/api/concrete.ml.onnx.ops_impl.md b/docs/references/api/concrete.ml.onnx.ops_impl.md
index a9952eb38..0ceb72791 100644
--- a/docs/references/api/concrete.ml.onnx.ops_impl.md
+++ b/docs/references/api/concrete.ml.onnx.ops_impl.md
@@ -8,7 +8,7 @@ ONNX ops implementation in Python + NumPy.
______________________________________________________________________
-
+
## function `cast_to_float`
@@ -28,7 +28,7 @@ Cast values to floating points.
______________________________________________________________________
-
+
## function `onnx_func_raw_args`
@@ -49,7 +49,7 @@ Decorate a numpy onnx function to flag the raw/non quantized inputs.
______________________________________________________________________
-
+
## function `numpy_where_body`
@@ -73,7 +73,7 @@ This function is not mapped to any ONNX operator (as opposed to numpy_where). It
______________________________________________________________________
-
+
## function `numpy_where`
@@ -95,7 +95,7 @@ Compute the equivalent of numpy.where.
______________________________________________________________________
-
+
## function `numpy_add`
@@ -118,7 +118,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-13
______________________________________________________________________
-
+
## function `numpy_constant`
@@ -140,7 +140,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Constant-13
______________________________________________________________________
-
+
## function `numpy_gemm`
@@ -176,7 +176,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-13
______________________________________________________________________
-
+
## function `numpy_matmul`
@@ -199,7 +199,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#MatMul-13
______________________________________________________________________
-
+
## function `numpy_relu`
@@ -221,7 +221,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14
______________________________________________________________________
-
+
## function `numpy_sigmoid`
@@ -243,7 +243,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-13
______________________________________________________________________
-
+
## function `numpy_softmax`
@@ -269,7 +269,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#softmax-13
______________________________________________________________________
-
+
## function `numpy_cos`
@@ -291,7 +291,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cos-7
______________________________________________________________________
-
+
## function `numpy_cosh`
@@ -313,7 +313,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cosh-9
______________________________________________________________________
-
+
## function `numpy_sin`
@@ -335,7 +335,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sin-7
______________________________________________________________________
-
+
## function `numpy_sinh`
@@ -357,7 +357,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sinh-9
______________________________________________________________________
-
+
## function `numpy_tan`
@@ -379,7 +379,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tan-7
______________________________________________________________________
-
+
## function `numpy_tanh`
@@ -401,7 +401,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tanh-13
______________________________________________________________________
-
+
## function `numpy_acos`
@@ -423,7 +423,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Acos-7
______________________________________________________________________
-
+
## function `numpy_acosh`
@@ -445,7 +445,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Acosh-9
______________________________________________________________________
-
+
## function `numpy_asin`
@@ -467,7 +467,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Asin-7
______________________________________________________________________
-
+
## function `numpy_asinh`
@@ -489,7 +489,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Asinh-9
______________________________________________________________________
-
+
## function `numpy_atan`
@@ -511,7 +511,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Atan-7
______________________________________________________________________
-
+
## function `numpy_atanh`
@@ -533,7 +533,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Atanh-9
______________________________________________________________________
-
+
## function `numpy_elu`
@@ -556,7 +556,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Elu-6
______________________________________________________________________
-
+
## function `numpy_selu`
@@ -584,7 +584,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Selu-6
______________________________________________________________________
-
+
## function `numpy_celu`
@@ -607,7 +607,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Celu-12
______________________________________________________________________
-
+
## function `numpy_leakyrelu`
@@ -630,7 +630,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#LeakyRelu-6
______________________________________________________________________
-
+
## function `numpy_thresholdedrelu`
@@ -653,7 +653,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ThresholdedRelu-10
______________________________________________________________________
-
+
## function `numpy_hardsigmoid`
@@ -681,7 +681,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-6
______________________________________________________________________
-
+
## function `numpy_softplus`
@@ -703,7 +703,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Softplus-1
______________________________________________________________________
-
+
## function `numpy_abs`
@@ -725,7 +725,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Abs-13
______________________________________________________________________
-
+
## function `numpy_div`
@@ -748,7 +748,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Div-14
______________________________________________________________________
-
+
## function `numpy_mul`
@@ -771,7 +771,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mul-14
______________________________________________________________________
-
+
## function `numpy_sub`
@@ -794,7 +794,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-14
______________________________________________________________________
-
+
## function `numpy_log`
@@ -816,7 +816,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Log-13
______________________________________________________________________
-
+
## function `numpy_erf`
@@ -838,7 +838,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Erf-13
______________________________________________________________________
-
+
## function `numpy_hardswish`
@@ -860,7 +860,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#hardswish-14
______________________________________________________________________
-
+
## function `numpy_exp`
@@ -882,7 +882,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13
______________________________________________________________________
-
+
## function `numpy_equal`
@@ -905,7 +905,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-11
______________________________________________________________________
-
+
## function `rounded_numpy_equal_for_trees`
@@ -933,7 +933,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-11
______________________________________________________________________
-
+
## function `numpy_equal_float`
@@ -956,7 +956,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-13
______________________________________________________________________
-
+
## function `numpy_not`
@@ -978,7 +978,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Not-1
______________________________________________________________________
-
+
## function `numpy_not_float`
@@ -1000,7 +1000,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Not-1
______________________________________________________________________
-
+
## function `numpy_greater`
@@ -1023,7 +1023,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-13
______________________________________________________________________
-
+
## function `numpy_greater_float`
@@ -1046,7 +1046,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-13
______________________________________________________________________
-
+
## function `numpy_greater_or_equal`
@@ -1069,7 +1069,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GreaterOrEqual-12
______________________________________________________________________
-
+
## function `numpy_greater_or_equal_float`
@@ -1092,7 +1092,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GreaterOrEqual-12
______________________________________________________________________
-
+
## function `numpy_less`
@@ -1115,7 +1115,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Less-13
______________________________________________________________________
-
+
## function `rounded_numpy_less_for_trees`
@@ -1143,7 +1143,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Less-13
______________________________________________________________________
-
+
## function `numpy_less_float`
@@ -1166,7 +1166,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Less-13
______________________________________________________________________
-
+
## function `numpy_less_or_equal`
@@ -1189,7 +1189,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#LessOrEqual-12
______________________________________________________________________
-
+
## function `rounded_numpy_less_or_equal_for_trees`
@@ -1217,7 +1217,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#LessOrEqual-12
______________________________________________________________________
-
+
## function `numpy_less_or_equal_float`
@@ -1240,7 +1240,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#LessOrEqual-12
______________________________________________________________________
-
+
## function `numpy_identity`
@@ -1262,7 +1262,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14
______________________________________________________________________
-
+
## function `numpy_transpose`
@@ -1285,7 +1285,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-13
______________________________________________________________________
-
+
## function `numpy_conv`
@@ -1326,7 +1326,7 @@ See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
______________________________________________________________________
-
+
## function `numpy_avgpool`
@@ -1365,7 +1365,7 @@ See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
______________________________________________________________________
-
+
## function `numpy_maxpool`
@@ -1406,7 +1406,7 @@ See: https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
______________________________________________________________________
-
+
## function `numpy_cast`
@@ -1429,7 +1429,7 @@ This function supports casting to booleans, floats, and double for traced values
______________________________________________________________________
-
+
## function `numpy_batchnorm`
@@ -1471,7 +1471,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-
______________________________________________________________________
-
+
## function `numpy_flatten`
@@ -1494,7 +1494,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13.
______________________________________________________________________
-
+
## function `numpy_or`
@@ -1517,7 +1517,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Or-7
______________________________________________________________________
-
+
## function `numpy_or_float`
@@ -1540,7 +1540,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Or-7
______________________________________________________________________
-
+
## function `numpy_round`
@@ -1562,7 +1562,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Round-11 Remark tha
______________________________________________________________________
-
+
## function `numpy_pow`
@@ -1585,7 +1585,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-13
______________________________________________________________________
-
+
## function `numpy_floor`
@@ -1607,7 +1607,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-1
______________________________________________________________________
-
+
## function `numpy_max`
@@ -1632,7 +1632,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Max-1
______________________________________________________________________
-
+
## function `numpy_min`
@@ -1657,7 +1657,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Max-1
______________________________________________________________________
-
+
## function `numpy_sign`
@@ -1679,7 +1679,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sign-9
______________________________________________________________________
-
+
## function `numpy_neg`
@@ -1701,7 +1701,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sign-9
______________________________________________________________________
-
+
## function `numpy_concatenate`
@@ -1724,7 +1724,7 @@ See https://github.com/onnx/onnx/blob/main/docs/Changelog.md#concat-13
______________________________________________________________________
-
+
## function `numpy_unfold`
@@ -1767,9 +1767,11 @@ ______________________________________________________________________
Type construct that marks an ndarray as a raw output of a quantized op.
+A raw output is an output that is a clear constant, such as a shape, a constant float or an index.
+
______________________________________________________________________
-
+
## class `ONNXMixedFunction`
@@ -1777,7 +1779,7 @@ A mixed quantized-raw valued onnx function.
ONNX functions will take inputs which can be either quantized or float. Some functions only take quantized inputs, but some functions take both types. For mixed functions we need to tag the parameters that do not need quantization. Thus quantized ops can know which inputs are not QuantizedArray and we avoid unnecessary wrapping of float values as QuantizedArrays.
-
+
### method `__init__`
diff --git a/docs/references/api/concrete.ml.pandas.client_engine.md b/docs/references/api/concrete.ml.pandas.client_engine.md
index 35d32eed0..43aeceec1 100644
--- a/docs/references/api/concrete.ml.pandas.client_engine.md
+++ b/docs/references/api/concrete.ml.pandas.client_engine.md
@@ -12,13 +12,13 @@ Define the framework used for managing keys (encrypt, decrypt) for encrypted dat
______________________________________________________________________
-
+
## class `ClientEngine`
Define a framework that manages keys.
-
+
### method `__init__`
@@ -28,7 +28,7 @@ __init__(keygen: bool = True, keys_path: Optional[Path, str] = None)
______________________________________________________________________
-
+
### method `decrypt_to_pandas`
@@ -48,12 +48,15 @@ Decrypt an encrypted data-frame using the loaded client and return a Pandas data
______________________________________________________________________
-
+
### method `encrypt_from_pandas`
```python
-encrypt_from_pandas(pandas_dataframe: DataFrame) → EncryptedDataFrame
+encrypt_from_pandas(
+ pandas_dataframe: DataFrame,
+ schema: Optional[Dict] = None
+) → EncryptedDataFrame
```
Encrypt a Pandas data-frame using the loaded client.
@@ -61,6 +64,7 @@ Encrypt a Pandas data-frame using the loaded client.
**Args:**
- `pandas_dataframe` (DataFrame): The Pandas data-frame to encrypt.
+- `schema` (Optional\[Dict\]): The input schema to consider. Defaults to None.
**Returns:**
@@ -68,7 +72,7 @@ Encrypt a Pandas data-frame using the loaded client.
______________________________________________________________________
-
+
### method `keygen`
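For context, `ClientEngine` is typically used to round-trip a Pandas data-frame through encryption. A minimal sketch, assuming a small integer data-frame of the kind supported by encrypted data-frames and leaving the new `schema` argument at its default; the server-side processing step is elided.

```python
import pandas
from concrete.ml.pandas.client_engine import ClientEngine

# Generates keys on instantiation (keygen=True by default) and stores them under keys_path
client = ClientEngine(keys_path="keys")

df = pandas.DataFrame({"id": [1, 2, 3], "score": [4, 5, 6]})

# Encrypt the data-frame; an input schema can optionally be passed (left to its
# None default here)
encrypted_df = client.encrypt_from_pandas(df, schema=None)

# ... the encrypted data-frame can now be processed server-side ...

# Decrypt an encrypted data-frame back into a Pandas data-frame
decrypted_df = client.decrypt_to_pandas(encrypted_df)
```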
diff --git a/docs/references/api/concrete.ml.pytest.torch_models.md b/docs/references/api/concrete.ml.pytest.torch_models.md
index d8d0b68d5..55c5a540c 100644
--- a/docs/references/api/concrete.ml.pytest.torch_models.md
+++ b/docs/references/api/concrete.ml.pytest.torch_models.md
@@ -43,7 +43,7 @@ Forward pass.
**Returns:**
-- `Tuple[torch.Tensor. torch.Tensor]`: Output of the network.
+- `Tuple[torch.Tensor, torch.Tensor]`: Outputs of the network.
______________________________________________________________________
@@ -1563,3 +1563,63 @@ Forward pass.
**Returns:**
- `torch.Tensor`: The model's output.
+
+______________________________________________________________________
+
+
+
+## class `IdentityExpandModel`
+
+Model that only adds an empty dimension at axis 0.
+
+This model is mostly useful for testing the composition feature.
+
+______________________________________________________________________
+
+
+
+### method `forward`
+
+```python
+forward(x)
+```
+
+Forward pass.
+
+**Args:**
+
+- `x` (torch.Tensor): The input of the model.
+
+**Returns:**
+
+- `Tuple[torch.Tensor, torch.Tensor]`: Outputs of the network.
+
+______________________________________________________________________
+
+
+
+## class `IdentityExpandMultiOutputModel`
+
+Model that only adds an empty dimension at axis 0, and returns the initial input as well.
+
+This model is mostly useful for testing the composition feature.
+
+______________________________________________________________________
+
+
+
+### method `forward`
+
+```python
+forward(x)
+```
+
+Forward pass.
+
+**Args:**
+
+- `x` (torch.Tensor): The input of the model.
+
+**Returns:**
+
+- `Tuple[torch.Tensor, torch.Tensor]`: Outputs of the network.
diff --git a/docs/references/api/concrete.ml.pytest.utils.md b/docs/references/api/concrete.ml.pytest.utils.md
index 8f900ea55..ae87301c8 100644
--- a/docs/references/api/concrete.ml.pytest.utils.md
+++ b/docs/references/api/concrete.ml.pytest.utils.md
@@ -13,7 +13,7 @@ Common functions or lists for test files, which can't be put in fixtures.
______________________________________________________________________
-
+
## function `get_sklearn_linear_models_and_datasets`
@@ -43,7 +43,7 @@ Get the pytest parameters to use for testing linear models.
______________________________________________________________________
-
+
## function `get_sklearn_tree_models_and_datasets`
@@ -73,7 +73,7 @@ Get the pytest parameters to use for testing tree-based models.
______________________________________________________________________
-
+
## function `get_sklearn_neural_net_models_and_datasets`
@@ -103,7 +103,7 @@ Get the pytest parameters to use for testing neural network models.
______________________________________________________________________
-
+
## function `get_sklearn_neighbors_models_and_datasets`
@@ -133,7 +133,7 @@ Get the pytest parameters to use for testing neighbor models.
______________________________________________________________________
-
+
## function `get_sklearn_all_models_and_datasets`
@@ -163,7 +163,7 @@ Get the pytest parameters to use for testing all models available in Concrete ML
______________________________________________________________________
-
+
## function `instantiate_model_generic`
@@ -186,7 +186,7 @@ Instantiate any Concrete ML model type.
______________________________________________________________________
-
+
## function `data_calibration_processing`
@@ -212,7 +212,7 @@ Reduce size of the given data-set.
______________________________________________________________________
-
+
## function `load_torch_model`
@@ -240,7 +240,7 @@ Load an object saved with torch.save() from a file or dict.
______________________________________________________________________
-
+
## function `values_are_equal`
@@ -263,7 +263,7 @@ This method takes into account objects of type None, numpy.ndarray, numpy.floati
______________________________________________________________________
-
+
## function `check_serialization`
@@ -289,7 +289,7 @@ This function serializes all objects using the `dump`, `dumps`, `load` and `load
______________________________________________________________________
-
+
## function `get_random_samples`
@@ -314,7 +314,7 @@ Select `n_sample` random elements from a 2D NumPy array.
______________________________________________________________________
-
+
## function `pandas_dataframe_are_equal`
diff --git a/docs/references/api/concrete.ml.quantization.post_training.md b/docs/references/api/concrete.ml.quantization.post_training.md
index 80a814f35..5db84f8ee 100644
--- a/docs/references/api/concrete.ml.quantization.post_training.md
+++ b/docs/references/api/concrete.ml.quantization.post_training.md
@@ -122,7 +122,7 @@ Following https://arxiv.org/abs/1712.05877 guidelines.
**Args:**
-- `*calibration_data (numpy.ndarray)`: Data that will be used to compute the bounds, scales and zero point values for every quantized object.
+- `calibration_data` (numpy.ndarray): Data that will be used to compute the bounds, scales and zero point values for every quantized object.
**Returns:**
@@ -130,7 +130,7 @@ Following https://arxiv.org/abs/1712.05877 guidelines.
______________________________________________________________________
-
+
## class `PostTrainingAffineQuantization`
@@ -221,7 +221,7 @@ Following https://arxiv.org/abs/1712.05877 guidelines.
**Args:**
-- `*calibration_data (numpy.ndarray)`: Data that will be used to compute the bounds, scales and zero point values for every quantized object.
+- `calibration_data` (numpy.ndarray): Data that will be used to compute the bounds, scales and zero point values for every quantized object.
**Returns:**
@@ -229,7 +229,7 @@ Following https://arxiv.org/abs/1712.05877 guidelines.
______________________________________________________________________
-
+
## class `PostTrainingQATImporter`
@@ -305,7 +305,7 @@ Following https://arxiv.org/abs/1712.05877 guidelines.
**Args:**
-- `*calibration_data (numpy.ndarray)`: Data that will be used to compute the bounds, scales and zero point values for every quantized object.
+- `calibration_data` (numpy.ndarray): Data that will be used to compute the bounds, scales and zero point values for every quantized object.
**Returns:**
diff --git a/docs/references/api/concrete.ml.quantization.quantized_module.md b/docs/references/api/concrete.ml.quantization.quantized_module.md
index a90f6def9..bebd6f9e8 100644
--- a/docs/references/api/concrete.ml.quantization.quantized_module.md
+++ b/docs/references/api/concrete.ml.quantization.quantized_module.md
@@ -67,12 +67,12 @@ Get the post-processing parameters.
______________________________________________________________________
-
+
### method `bitwidth_and_range_report`
```python
-bitwidth_and_range_report() → Optional[Dict[str, Dict[str, Union[Tuple[int, ], int]]]]
+bitwidth_and_range_report() → Union[Dict[str, Dict[str, Union[Tuple[int, ...], int]]], NoneType]
```
Report the ranges and bit-widths for layers that mix encrypted integer values.
@@ -83,7 +83,7 @@ Report the ranges and bit-widths for layers that mix encrypted integer values.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -99,7 +99,7 @@ Check if the quantized module is compiled.
______________________________________________________________________
-
+
### method `compile`
@@ -139,7 +139,7 @@ Compile the module's forward function.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -159,7 +159,7 @@ Take the last layer q_out and use its de-quant function.
______________________________________________________________________
-
+
### method `dump`
@@ -175,7 +175,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -191,7 +191,7 @@ Dump itself to a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -207,7 +207,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `forward`
@@ -235,7 +235,7 @@ This method executes the forward pass in the clear, with simulation or in FHE. I
______________________________________________________________________
-
+
### method `load_dict`
@@ -255,12 +255,12 @@ Load itself from a string.
______________________________________________________________________
-
+
### method `post_processing`
```python
-post_processing(values: ndarray) → ndarray
+post_processing(*values: ndarray) → Union[ndarray, Tuple[ndarray, ...]]
```
Apply post-processing to the de-quantized values.
@@ -273,31 +273,33 @@ For quantized modules, there is no post-processing step but the method is kept t
**Returns:**
-- `numpy.ndarray`: The post-processed values.
+- `Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]`: The post-processed values.
______________________________________________________________________
-
+
### method `quantize_input`
```python
-quantize_input(*x: ndarray) → Union[ndarray, Tuple[ndarray, ]]
+quantize_input(
+ *x: Optional[ndarray]
+) → Union[ndarray, Tuple[Union[ndarray, NoneType], ...]]
```
Take the inputs in fp32 and quantize it using the learned quantization parameters.
**Args:**
-- `x` (numpy.ndarray): Floating point x.
+- `x` (Optional\[numpy.ndarray\]): Floating point x or None.
**Returns:**
-- `Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]`: Quantized (numpy.int64) x.
+- `Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]`: Quantized (numpy.int64) x, or None if the corresponding input is None.
______________________________________________________________________
-
+
### method `quantized_forward`
@@ -321,7 +323,7 @@ Forward function for the FHE circuit.
______________________________________________________________________
-
+
### method `set_inputs_quantization_parameters`
@@ -337,7 +339,7 @@ Set the quantization parameters for the module's inputs.
______________________________________________________________________
-
+
### method `set_reduce_sum_copy`
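The `QuantizedModule` methods above chain in a fixed order: quantize the inputs, run the integer forward pass, de-quantize, then post-process. A minimal sketch of that chain, assuming `quantized_module` is a single-input, single-output module produced by Concrete ML's torch compilation (for example `compile_torch_model`) and `x` is a float NumPy array of the expected shape.

```python
# Step-by-step execution in the clear
q_x = quantized_module.quantize_input(x)
q_y = quantized_module.quantized_forward(q_x)
y = quantized_module.post_processing(quantized_module.dequantize_output(q_y))

# Equivalent one-shot call; fhe can be "disable", "simulate" or "execute"
# (the latter two require compiling the module first)
y_direct = quantized_module.forward(x, fhe="disable")
```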
diff --git a/docs/references/api/concrete.ml.quantization.quantized_module_passes.md b/docs/references/api/concrete.ml.quantization.quantized_module_passes.md
index 79cc3c7fb..60a151c02 100644
--- a/docs/references/api/concrete.ml.quantization.quantized_module_passes.md
+++ b/docs/references/api/concrete.ml.quantization.quantized_module_passes.md
@@ -41,7 +41,7 @@ ______________________________________________________________________
### method `compute_op_predecessors`
```python
-compute_op_predecessors() → DefaultDict[Optional[QuantizedOp], List[Tuple[Optional[QuantizedOp], str]]]
+compute_op_predecessors() → DefaultDict[Union[QuantizedOp, NoneType], List[Tuple[Union[QuantizedOp, NoneType], str]]]
```
Compute the predecessors for each QuantizedOp in a QuantizedModule.
@@ -61,7 +61,7 @@ ______________________________________________________________________
```python
detect_patterns(
predecessors: DefaultDict[Optional[QuantizedOp], List[Tuple[Optional[QuantizedOp], str]]]
-) → Dict[QuantizedMixingOp, Tuple[List[Optional[QuantizedOp]], Optional[QuantizedOp]]]
+) → Dict[QuantizedMixingOp, Tuple[List[Union[QuantizedOp, NoneType]], Union[QuantizedOp, NoneType]]]
```
Detect the patterns that can be optimized with roundPBS in the QuantizedModule.
@@ -107,7 +107,7 @@ ______________________________________________________________________
### method `process`
```python
-process() → Dict[QuantizedMixingOp, Tuple[List[Optional[QuantizedOp]], Optional[QuantizedOp]]]
+process() → Dict[QuantizedMixingOp, Tuple[List[Union[QuantizedOp, NoneType]], Union[QuantizedOp, NoneType]]]
```
Analyze an ONNX graph and detect Gemm/Conv patterns that can use RoundPBS.
@@ -129,7 +129,7 @@ ______________________________________________________________________
```python
process_patterns(
valid_paths: Dict[QuantizedMixingOp, Tuple[List[Optional[QuantizedOp]], Optional[QuantizedOp]]]
-) → Dict[QuantizedMixingOp, Tuple[List[Optional[QuantizedOp]], Optional[QuantizedOp]]]
+) → Dict[QuantizedMixingOp, Tuple[List[Union[QuantizedOp, NoneType]], Union[QuantizedOp, NoneType]]]
```
Configure the rounding bits of roundPBS for the optimizable operations.
diff --git a/docs/references/api/concrete.ml.quantization.quantized_ops.md b/docs/references/api/concrete.ml.quantization.quantized_ops.md
index 3916a7441..e2e3bc371 100644
--- a/docs/references/api/concrete.ml.quantization.quantized_ops.md
+++ b/docs/references/api/concrete.ml.quantization.quantized_ops.md
@@ -1132,7 +1132,27 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
+
+### method `calibrate`
+
+```python
+calibrate(*inputs: ndarray) → ndarray
+```
+
+Create corresponding QuantizedArray for the output of the activation function.
+
+**Args:**
+
+- `*inputs (numpy.ndarray)`: Calibration sample inputs.
+
+**Returns:**
+
+- `numpy.ndarray`: the output values for the provided calibration samples.
+
+______________________________________________________________________
+
+
## class `QuantizedFlatten`
@@ -1150,7 +1170,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1168,7 +1188,7 @@ Flatten operation cannot be fused since it must be performed over integer tensor
______________________________________________________________________
-
+
### method `q_impl`
@@ -1192,13 +1212,13 @@ Flatten the input integer encrypted tensor.
______________________________________________________________________
-
+
## class `QuantizedReduceSum`
ReduceSum with encrypted input.
-
+
### method `__init__`
@@ -1239,7 +1259,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `calibrate`
@@ -1259,7 +1279,7 @@ Create corresponding QuantizedArray for the output of the activation function.
______________________________________________________________________
-
+
### method `q_impl`
@@ -1285,7 +1305,7 @@ Sum the encrypted tensor's values along the given axes.
______________________________________________________________________
-
+
## class `QuantizedErf`
@@ -1303,7 +1323,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedNot`
@@ -1321,13 +1341,13 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedBrevitasQuant`
Brevitas uniform quantization with encrypted input.
-
+
### method `__init__`
@@ -1370,7 +1390,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `calibrate`
@@ -1390,7 +1410,7 @@ Create corresponding QuantizedArray for the output of Quantization function.
______________________________________________________________________
-
+
### method `q_impl`
@@ -1414,7 +1434,7 @@ Quantize values.
______________________________________________________________________
-
+
## class `QuantizedTranspose`
@@ -1434,7 +1454,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1452,7 +1472,7 @@ Transpose can not be fused since it must be performed over integer tensors as it
______________________________________________________________________
-
+
### method `q_impl`
@@ -1476,7 +1496,7 @@ Transpose the input integer encrypted tensor.
______________________________________________________________________
-
+
## class `QuantizedFloor`
@@ -1494,7 +1514,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedMax`
@@ -1512,7 +1532,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedMin`
@@ -1530,7 +1550,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedNeg`
@@ -1548,7 +1568,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedSign`
@@ -1566,7 +1586,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedUnsqueeze`
@@ -1584,7 +1604,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1602,7 +1622,7 @@ Unsqueeze can not be fused since it must be performed over integer tensors as it
______________________________________________________________________
-
+
### method `q_impl`
@@ -1626,7 +1646,7 @@ Unsqueeze the input tensors on a given axis.
______________________________________________________________________
-
+
## class `QuantizedConcat`
@@ -1644,7 +1664,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1662,7 +1682,7 @@ Concatenation can not be fused since it must be performed over integer tensors a
______________________________________________________________________
-
+
### method `q_impl`
@@ -1686,7 +1706,7 @@ Concatenate the input tensors on a given axis.
______________________________________________________________________
-
+
## class `QuantizedSqueeze`
@@ -1704,7 +1724,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1722,7 +1742,7 @@ Squeeze can not be fused since it must be performed over integer tensors as it r
______________________________________________________________________
-
+
### method `q_impl`
@@ -1746,7 +1766,7 @@ Squeeze the input tensors on a given axis.
______________________________________________________________________
-
+
## class `ONNXShape`
@@ -1764,7 +1784,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1782,7 +1802,7 @@ This operation returns the shape of the tensor and thus can not be fused into a
______________________________________________________________________
-
+
### method `q_impl`
@@ -1795,7 +1815,7 @@ q_impl(
______________________________________________________________________
-
+
## class `ONNXConstantOfShape`
@@ -1813,7 +1833,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1831,7 +1851,7 @@ This operation returns a new encrypted tensor and thus can not be fused.
______________________________________________________________________
-
+
## class `ONNXGather`
@@ -1851,7 +1871,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1869,7 +1889,7 @@ This operation returns values from a tensor and thus can not be fused into a uni
______________________________________________________________________
-
+
### method `q_impl`
@@ -1882,7 +1902,7 @@ q_impl(
______________________________________________________________________
-
+
## class `ONNXSlice`
@@ -1900,7 +1920,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1918,7 +1938,7 @@ This operation returns values from a tensor and thus can not be fused into a uni
______________________________________________________________________
-
+
### method `q_impl`
@@ -1931,7 +1951,7 @@ q_impl(
______________________________________________________________________
-
+
## class `QuantizedExpand`
@@ -1949,7 +1969,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `can_fuse`
@@ -1967,7 +1987,7 @@ Unsqueeze can not be fused since it must be performed over integer tensors as it
______________________________________________________________________
-
+
### method `q_impl`
@@ -1991,7 +2011,7 @@ Expand the input tensor to a specified shape.
______________________________________________________________________
-
+
## class `QuantizedEqual`
@@ -1999,7 +2019,7 @@ Comparison operator ==.
Only supports comparison with a constant.
-
+
### method `__init__`
@@ -2026,13 +2046,13 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
## class `QuantizedUnfold`
Quantized Unfold op.
-
+
### method `__init__`
@@ -2059,7 +2079,7 @@ Get the names of encrypted integer tensors that are used by this op.
______________________________________________________________________
-
+
### method `q_impl`
diff --git a/docs/references/api/concrete.ml.quantization.quantizers.md b/docs/references/api/concrete.ml.quantization.quantizers.md
index cdef46c2d..8b259c098 100644
--- a/docs/references/api/concrete.ml.quantization.quantizers.md
+++ b/docs/references/api/concrete.ml.quantization.quantizers.md
@@ -638,7 +638,7 @@ ______________________________________________________________________
### method `dequant`
```python
-dequant(qvalues: 'ndarray') → Union[ndarray, Tracer]
+dequant(qvalues: 'ndarray') → Union[float, ndarray, Tracer]
```
De-quantize values.
diff --git a/docs/references/api/concrete.ml.sklearn.base.md b/docs/references/api/concrete.ml.sklearn.base.md
index e7419b077..e98cd234c 100644
--- a/docs/references/api/concrete.ml.sklearn.base.md
+++ b/docs/references/api/concrete.ml.sklearn.base.md
@@ -14,7 +14,7 @@ Base classes for all estimators.
______________________________________________________________________
-
+
## class `BaseEstimator`
@@ -26,7 +26,7 @@ This class does not inherit from sklearn.base.BaseEstimator as it creates some c
- `_is_a_public_cml_model` (bool): Private attribute indicating if the class is a public model (as opposed to base or mixin classes).
-
+
### method `__init__`
@@ -84,7 +84,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -100,7 +100,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -116,7 +116,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -150,7 +150,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -172,7 +172,7 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
### method `dump`
@@ -188,7 +188,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -204,7 +204,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -220,7 +220,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -243,7 +243,7 @@ The fitted estimator.
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -270,7 +270,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -292,7 +292,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -312,7 +312,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -336,7 +336,7 @@ For some simple models such a linear regression, there is no post-processing ste
______________________________________________________________________
-
+
### method `predict`
@@ -360,7 +360,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -382,7 +382,7 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
## class `BaseClassifier`
@@ -390,7 +390,7 @@ Base class for linear and tree-based classifiers in Concrete ML.
This class inherits from BaseEstimator and modifies some of its methods in order to align them with classifier behaviors. This notably includes applying a sigmoid/softmax post-processing to the predicted values as well as handling a mapping of classes in case they are not ordered.
-
+
### method `__init__`
@@ -472,7 +472,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -488,7 +488,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -504,7 +504,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -538,7 +538,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -560,7 +560,7 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
### method `dump`
@@ -576,7 +576,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -592,7 +592,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -608,7 +608,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -618,7 +618,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -645,7 +645,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -667,7 +667,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -687,7 +687,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -697,7 +697,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -710,7 +710,7 @@ predict(
______________________________________________________________________
-
+
### method `predict_proba`
@@ -734,7 +734,7 @@ Predict class probabilities.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -756,13 +756,13 @@ This step ensures that the fit method has been called.
______________________________________________________________________
-
+
## class `QuantizedTorchEstimatorMixin`
Mixin that provides quantization for a torch module and follows the Estimator API.
-
+
### method `__init__`
@@ -838,7 +838,7 @@ Get the output quantizers.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -854,7 +854,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -870,7 +870,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -888,7 +888,7 @@ compile(
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -898,7 +898,7 @@ dequantize_output(*q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -914,7 +914,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -930,7 +930,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -946,7 +946,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -971,7 +971,7 @@ The fitted estimator.
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1002,7 +1002,7 @@ The Concrete ML and equivalent skorch fitted estimators.
______________________________________________________________________
-
+
### method `get_params`
@@ -1024,7 +1024,7 @@ This method is overloaded in order to make sure that auto-computed parameters ar
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -1034,7 +1034,7 @@ get_sklearn_params(deep: 'bool' = True) → Dict
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1054,7 +1054,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1064,7 +1064,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1088,7 +1088,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `prune`
@@ -1116,7 +1116,7 @@ A new pruned copy of the Neural Network model.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -1126,7 +1126,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `BaseTreeEstimatorMixin`
@@ -1134,7 +1134,7 @@ Mixin class for tree-based estimators.
This class inherits from sklearn.base.BaseEstimator in order to have access to scikit-learn's `get_params` and `set_params` methods.
-
+
### method `__init__`
@@ -1196,7 +1196,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -1212,7 +1212,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -1228,7 +1228,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -1238,7 +1238,7 @@ compile(*args, **kwargs) → Circuit
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -1248,7 +1248,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -1264,7 +1264,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -1280,7 +1280,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -1296,7 +1296,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -1306,7 +1306,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1333,7 +1333,34 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
+
+### classmethod `from_sklearn_model`
+
+```python
+from_sklearn_model(
+ sklearn_model: 'BaseEstimator',
+ X: 'Optional[ndarray]' = None,
+ n_bits: 'int' = 10
+)
+```
+
+Build a FHE-compliant model using a fitted scikit-learn model.
+
+**Args:**
+
+- `sklearn_model` (sklearn.base.BaseEstimator): The fitted scikit-learn model to convert.
+- `X` (Optional\[Data\]): A representative set of input values used for computing quantization parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is usually the training data-set or a sub-set of it.
+- `n_bits` (int): Number of bits to quantize the model. Defaults to 10. If an int is passed for n_bits, the value will be used for quantizing inputs and weights. If a dict is passed, then it should contain "op_inputs" and "op_weights" as keys with the corresponding number of quantization bits:
+  \- op_inputs: number of bits to quantize the input values
+  \- op_weights: number of bits to quantize the learned parameters
+
+**Returns:**
+The FHE-compliant fitted model.
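+
+For illustration, here is a minimal usage sketch (not part of the generated reference); the estimator choice, random data and `n_bits` value below are placeholder assumptions.
+
```python
# Sketch: convert a fitted scikit-learn tree model to its FHE-compliant
# Concrete ML equivalent (illustrative data and parameters).
import numpy
from sklearn.tree import DecisionTreeRegressor as SklearnDTR
from concrete.ml.sklearn import DecisionTreeRegressor as ConcreteDTR

X = numpy.random.rand(100, 4)
y = numpy.random.rand(100)

sklearn_model = SklearnDTR(max_depth=4).fit(X, y)

# X is used as a representative data-set to compute quantization parameters
concrete_model = ConcreteDTR.from_sklearn_model(sklearn_model, X, n_bits=10)
```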
+
+______________________________________________________________________
+
+
### method `get_sklearn_params`
@@ -1355,7 +1382,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1375,7 +1402,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1385,7 +1412,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1398,7 +1425,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -1408,7 +1435,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `BaseTreeRegressorMixin`
@@ -1416,7 +1443,7 @@ Mixin class for tree-based regressors.
This class is used to create a tree-based regressor class that inherits from sklearn.base.RegressorMixin, which essentially gives access to scikit-learn's `score` method for regressors.
-
+
### method `__init__`
@@ -1478,7 +1505,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -1494,7 +1521,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -1510,7 +1537,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -1520,7 +1547,7 @@ compile(*args, **kwargs) → Circuit
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -1530,7 +1557,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -1546,7 +1573,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -1562,7 +1589,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -1578,7 +1605,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -1588,7 +1615,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1615,7 +1642,34 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
+
+### classmethod `from_sklearn_model`
+
+```python
+from_sklearn_model(
+ sklearn_model: 'BaseEstimator',
+ X: 'Optional[ndarray]' = None,
+ n_bits: 'int' = 10
+)
+```
+
+Build a FHE-compliant model using a fitted scikit-learn model.
+
+**Args:**
+
+- `sklearn_model` (sklearn.base.BaseEstimator): The fitted scikit-learn model to convert.
+- `X` (Optional\[Data\]): A representative set of input values used for computing quantization parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is usually the training data-set or a sub-set of it.
+- `n_bits` (int): Number of bits to quantize the model. Defaults to 10. If an int is passed for n_bits, the value will be used for quantizing inputs and weights. If a dict is passed, then it should contain "op_inputs" and "op_weights" as keys with the corresponding number of quantization bits:
+  \- op_inputs: number of bits to quantize the input values
+  \- op_weights: number of bits to quantize the learned parameters
+
+**Returns:**
+The FHE-compliant fitted model.
+
+______________________________________________________________________
+
+
### method `get_sklearn_params`
@@ -1637,7 +1691,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1657,7 +1711,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1667,7 +1721,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1680,7 +1734,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -1690,7 +1744,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `BaseTreeClassifierMixin`
@@ -1700,7 +1754,7 @@ This class is used to create a tree-based classifier class that inherits from sk
Additionally, this class adjusts some of the tree-based base class's methods in order to make them compliant with classification workflows.
-
+
### method `__init__`
@@ -1786,7 +1840,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -1802,7 +1856,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -1818,7 +1872,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -1828,7 +1882,7 @@ compile(*args, **kwargs) → Circuit
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -1838,7 +1892,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -1854,7 +1908,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -1870,7 +1924,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -1886,7 +1940,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -1896,7 +1950,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -1923,7 +1977,34 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
+
+### classmethod `from_sklearn_model`
+
+```python
+from_sklearn_model(
+ sklearn_model: 'BaseEstimator',
+ X: 'Optional[ndarray]' = None,
+ n_bits: 'int' = 10
+)
+```
+
+Build a FHE-compliant model using a fitted scikit-learn model.
+
+**Args:**
+
+- `sklearn_model` (sklearn.base.BaseEstimator): The fitted scikit-learn model to convert.
+- `X` (Optional\[Data\]): A representative set of input values used for computing quantization parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is usually the training data-set or a sub-set of it.
+- `n_bits` (int): Number of bits to quantize the model. Defaults to 10. If an int is passed for n_bits, the value will be used for quantizing inputs and weights. If a dict is passed, then it should contain "op_inputs" and "op_weights" as keys with the corresponding number of quantization bits:
+  \- op_inputs: number of bits to quantize the input values
+  \- op_weights: number of bits to quantize the learned parameters
+
+**Returns:**
+The FHE-compliant fitted model.
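+
+As a complementary sketch, a converted classifier can then be compiled and evaluated with FHE simulation; the `RandomForestClassifier` wrapper, data and parameter values below are illustrative assumptions rather than prescribed usage.
+
```python
# Sketch: convert a fitted scikit-learn classifier, then compile the result
# and run predictions in FHE simulation (illustrative data and parameters).
import numpy
from sklearn.ensemble import RandomForestClassifier as SklearnRF
from concrete.ml.sklearn import RandomForestClassifier as ConcreteRF

X = numpy.random.rand(100, 6)
y = (X[:, 0] > 0.5).astype(int)

sklearn_model = SklearnRF(n_estimators=10, max_depth=4).fit(X, y)
concrete_model = ConcreteRF.from_sklearn_model(sklearn_model, X, n_bits=8)

# Compile to an FHE circuit, then predict using simulation
concrete_model.compile(X)
y_pred = concrete_model.predict(X[:5], fhe="simulate")
```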
+
+______________________________________________________________________
+
+
### method `get_sklearn_params`
@@ -1945,7 +2026,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -1965,7 +2046,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -1975,7 +2056,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -1988,7 +2069,7 @@ predict(
______________________________________________________________________
-
+
### method `predict_proba`
@@ -2012,7 +2093,7 @@ Predict class probabilities.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -2022,7 +2103,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnLinearModelMixin`
@@ -2030,7 +2111,7 @@ A Mixin class for sklearn linear models with FHE.
This class inherits from sklearn.base.BaseEstimator in order to have access to scikit-learn's `get_params` and `set_params` methods.
-
+
### method `__init__`
@@ -2092,7 +2173,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -2108,7 +2189,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -2124,7 +2205,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -2158,7 +2239,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -2168,7 +2249,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -2184,7 +2265,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -2200,7 +2281,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -2216,7 +2297,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -2226,7 +2307,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -2253,7 +2334,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -2280,7 +2361,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -2302,7 +2383,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -2322,7 +2403,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -2346,7 +2427,7 @@ For some simple models such a linear regression, there is no post-processing ste
______________________________________________________________________
-
+
### method `predict`
@@ -2370,7 +2451,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -2380,7 +2461,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnLinearRegressorMixin`
@@ -2388,7 +2469,7 @@ A Mixin class for sklearn linear regressors with FHE.
This class is used to create a linear regressor class that inherits from sklearn.base.RegressorMixin, which essentially gives access to scikit-learn's `score` method for regressors.
-
+
### method `__init__`
@@ -2450,7 +2531,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -2466,7 +2547,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -2482,7 +2563,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -2516,7 +2597,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -2526,7 +2607,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -2542,7 +2623,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -2558,7 +2639,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -2574,7 +2655,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -2584,7 +2665,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -2611,7 +2692,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -2638,7 +2719,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -2660,7 +2741,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -2680,7 +2761,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -2704,7 +2785,7 @@ For some simple models such a linear regression, there is no post-processing ste
______________________________________________________________________
-
+
### method `predict`
@@ -2728,7 +2809,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -2738,7 +2819,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnLinearClassifierMixin`
@@ -2748,7 +2829,7 @@ This class is used to create a linear classifier class that inherits from sklear
Additionally, this class adjusts some of the linear model base class's methods in order to make them compliant with classification workflows.
-
+
### method `__init__`
@@ -2834,7 +2915,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -2850,7 +2931,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -2866,7 +2947,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -2900,7 +2981,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `decision_function`
@@ -2924,7 +3005,7 @@ Predict confidence scores.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -2934,7 +3015,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -2950,7 +3031,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -2966,7 +3047,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -2982,7 +3063,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -2992,7 +3073,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -3019,7 +3100,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -3046,7 +3127,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -3068,7 +3149,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -3088,7 +3169,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -3098,7 +3179,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -3111,7 +3192,7 @@ predict(
______________________________________________________________________
-
+
### method `predict_proba`
@@ -3124,7 +3205,7 @@ predict_proba(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -3134,7 +3215,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnSGDRegressorMixin`
@@ -3142,7 +3223,7 @@ A Mixin class for sklearn SGD regressors with FHE.
This class is used to create an SGD regressor class that can be exported to ONNX using Hummingbird.
-
+
### method `__init__`
@@ -3204,7 +3285,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -3220,7 +3301,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -3236,7 +3317,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -3270,7 +3351,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -3280,7 +3361,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -3296,7 +3377,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -3312,7 +3393,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -3328,7 +3409,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -3338,7 +3419,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -3365,7 +3446,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -3392,7 +3473,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -3414,7 +3495,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -3434,7 +3515,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -3458,7 +3539,7 @@ For some simple models such a linear regression, there is no post-processing ste
______________________________________________________________________
-
+
### method `predict`
@@ -3482,7 +3563,7 @@ Predict values for X, in FHE or in the clear.
______________________________________________________________________
-
+
### method `quantize_input`
@@ -3492,7 +3573,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnSGDClassifierMixin`
@@ -3500,7 +3581,7 @@ A Mixin class for sklearn SGD classifiers with FHE.
This class is used to create an SGD classifier class that can be exported to ONNX using Hummingbird.
-
+
### method `__init__`
@@ -3586,7 +3667,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -3602,7 +3683,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -3618,7 +3699,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -3652,7 +3733,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `decision_function`
@@ -3676,7 +3757,7 @@ Predict confidence scores.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -3686,7 +3767,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -3702,7 +3783,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -3718,7 +3799,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -3734,7 +3815,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -3744,7 +3825,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -3771,7 +3852,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### classmethod `from_sklearn_model`
@@ -3798,7 +3879,7 @@ The FHE-compliant fitted model.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -3820,7 +3901,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -3840,7 +3921,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `post_processing`
@@ -3850,7 +3931,7 @@ post_processing(y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `predict`
@@ -3863,7 +3944,7 @@ predict(
______________________________________________________________________
-
+
### method `predict_proba`
@@ -3876,7 +3957,7 @@ predict_proba(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -3886,7 +3967,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnKNeighborsMixin`
@@ -3894,7 +3975,7 @@ A Mixin class for sklearn KNeighbors models with FHE.
This class inherits from sklearn.base.BaseEstimator in order to have access to scikit-learn's `get_params` and `set_params` methods.
-
+
### method `__init__`
@@ -3954,7 +4035,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -3970,7 +4051,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -3986,7 +4067,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -4020,7 +4101,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -4030,7 +4111,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -4046,7 +4127,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -4062,7 +4143,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -4078,7 +4159,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -4088,7 +4169,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -4115,7 +4196,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -4137,7 +4218,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### method `get_topk_labels`
@@ -4161,7 +4242,7 @@ Return the K-nearest labels of each point.
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -4181,7 +4262,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `majority_vote`
@@ -4201,7 +4282,7 @@ Determine the most common class among nearest neighborsfor each query.
______________________________________________________________________
-
+
### method `post_processing`
@@ -4223,7 +4304,7 @@ For KNN, the de-quantization step is not required. Because \_inference returns t
______________________________________________________________________
-
+
### method `predict`
@@ -4236,7 +4317,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
@@ -4246,7 +4327,7 @@ quantize_input(X: 'ndarray') → ndarray
______________________________________________________________________
-
+
## class `SklearnKNeighborsClassifierMixin`
@@ -4254,7 +4335,7 @@ A Mixin class for sklearn KNeighbors classifiers with FHE.
This class is used to create a KNeighbors classifier class that inherits from SklearnKNeighborsMixin and sklearn.base.ClassifierMixin. By inheriting from sklearn.base.ClassifierMixin, it allows this class to be recognized as a classifier.
-
+
### method `__init__`
@@ -4314,7 +4395,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `check_model_is_compiled`
@@ -4330,7 +4411,7 @@ Check if the model is compiled.
______________________________________________________________________
-
+
### method `check_model_is_fitted`
@@ -4346,7 +4427,7 @@ Check if the model is fitted.
______________________________________________________________________
-
+
### method `compile`
@@ -4380,7 +4461,7 @@ Compile the model.
______________________________________________________________________
-
+
### method `dequantize_output`
@@ -4390,7 +4471,7 @@ dequantize_output(q_y_preds: 'ndarray') → ndarray
______________________________________________________________________
-
+
### method `dump`
@@ -4406,7 +4487,7 @@ Dump itself to a file.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -4422,7 +4503,7 @@ Dump the object as a dict.
______________________________________________________________________
-
+
### method `dumps`
@@ -4438,7 +4519,7 @@ Dump itself to a string.
______________________________________________________________________
-
+
### method `fit`
@@ -4448,7 +4529,7 @@ fit(X: 'Data', y: 'Target', **fit_parameters)
______________________________________________________________________
-
+
### method `fit_benchmark`
@@ -4475,7 +4556,7 @@ The Concrete ML and float equivalent fitted estimators.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -4497,7 +4578,7 @@ This method is used to instantiate a scikit-learn model using the Concrete ML mo
______________________________________________________________________
-
+
### method `get_topk_labels`
@@ -4521,7 +4602,7 @@ Return the K-nearest labels of each point.
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -4541,7 +4622,7 @@ Load itself from a dict.
______________________________________________________________________
-
+
### method `majority_vote`
@@ -4561,7 +4642,7 @@ Determine the most common class among nearest neighborsfor each query.
______________________________________________________________________
-
+
### method `post_processing`
@@ -4583,7 +4664,7 @@ For KNN, the de-quantization step is not required. Because \_inference returns t
______________________________________________________________________
-
+
### method `predict`
@@ -4596,7 +4677,7 @@ predict(
______________________________________________________________________
-
+
### method `quantize_input`
diff --git a/docs/references/api/concrete.ml.sklearn.glm.md b/docs/references/api/concrete.ml.sklearn.glm.md
index ac61eb76e..fe8e7f200 100644
--- a/docs/references/api/concrete.ml.sklearn.glm.md
+++ b/docs/references/api/concrete.ml.sklearn.glm.md
@@ -22,7 +22,7 @@ A Poisson regression model with FHE.
For more details on PoissonRegressor please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.PoissonRegressor.html
-
+
### method `__init__`
@@ -31,7 +31,6 @@ __init__(
n_bits: 'Union[int, dict]' = 8,
alpha: 'float' = 1.0,
fit_intercept: 'bool' = True,
- solver: 'str' = 'lbfgs',
max_iter: 'int' = 100,
tol: 'float' = 0.0001,
warm_start: 'bool' = False,
@@ -128,7 +127,7 @@ predict(
______________________________________________________________________
-
+
## class `GammaRegressor`
@@ -142,7 +141,7 @@ A Gamma regression model with FHE.
For more details on GammaRegressor please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.GammaRegressor.html
-
+
### method `__init__`
@@ -151,7 +150,6 @@ __init__(
n_bits: 'Union[int, dict]' = 8,
alpha: 'float' = 1.0,
fit_intercept: 'bool' = True,
- solver: 'str' = 'lbfgs',
max_iter: 'int' = 100,
tol: 'float' = 0.0001,
warm_start: 'bool' = False,
@@ -248,7 +246,7 @@ predict(
______________________________________________________________________
-
+
## class `TweedieRegressor`
@@ -262,7 +260,7 @@ A Tweedie regression model with FHE.
For more details on TweedieRegressor please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.TweedieRegressor.html
-
+
### method `__init__`
@@ -273,8 +271,7 @@ __init__(
alpha: 'float' = 1.0,
fit_intercept: 'bool' = True,
link: 'str' = 'auto',
- max_iter: 'int' = 1000,
- solver: 'str' = 'lbfgs',
+ max_iter: 'int' = 100,
tol: 'float' = 0.0001,
warm_start: 'bool' = False,
verbose: 'int' = 0
@@ -327,7 +324,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -337,7 +334,7 @@ dump_dict() → Dict
______________________________________________________________________
-
+
### classmethod `load_dict`
diff --git a/docs/references/api/concrete.ml.sklearn.linear_model.md b/docs/references/api/concrete.ml.sklearn.linear_model.md
index 925e3bfee..986a2d4b5 100644
--- a/docs/references/api/concrete.ml.sklearn.linear_model.md
+++ b/docs/references/api/concrete.ml.sklearn.linear_model.md
@@ -8,7 +8,7 @@ Implement sklearn linear model.
______________________________________________________________________
-
+
## class `LinearRegression`
@@ -22,12 +22,19 @@ A linear regression model with FHE.
For more details on LinearRegression please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html
-
+
### method `__init__`
```python
-__init__(n_bits=8, fit_intercept=True, copy_X=True, n_jobs=None, positive=False)
+__init__(
+ n_bits=8,
+ fit_intercept=True,
+ normalize='deprecated',
+ copy_X=True,
+ n_jobs=None,
+ positive=False
+)
```
______________________________________________________________________
@@ -76,7 +83,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -86,7 +93,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -96,7 +103,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `SGDClassifier`
@@ -113,7 +120,7 @@ An FHE linear classifier model fitted with stochastic gradient descent.
For more details on SGDClassifier please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html
-
+
### method `__init__`
@@ -216,7 +223,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -226,7 +233,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### method `fit`
@@ -267,7 +274,7 @@ The fitted estimator.
______________________________________________________________________
-
+
### method `get_sklearn_params`
@@ -277,7 +284,7 @@ get_sklearn_params(deep: bool = True) → dict
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -287,12 +294,17 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
### method `partial_fit`
```python
-partial_fit(X: ndarray, y: ndarray, fhe: Optional[str, FheMode] = None)
+partial_fit(
+ X: ndarray,
+ y: ndarray,
+ fhe: Optional[str, FheMode] = None,
+ classes=None
+)
```
Fit SGDClassifier for a single iteration.
@@ -304,6 +316,7 @@ This function does one iteration of SGD training. Looping n_times over this func
- `X` (Data): The training data, as a Numpy array, Torch tensor, Pandas DataFrame or List.
- `y` (Target): The target data, as a Numpy array, Torch tensor, Pandas DataFrame, Pandas Series or List.
- `fhe` (Optional\[Union\[str, FheMode\]\]): The mode to use for FHE training. Can be FheMode.DISABLE for Concrete ML Python (quantized) training, FheMode.SIMULATE for FHE simulation and FheMode.EXECUTE for actual FHE execution. Can also be the string representation of any of these values. If None, training is done in floating points in the clear through scikit-learn. Default to None.
+- `classes` (Optional\[numpy.ndarray\]): The classes in the dataset. It needs to be provided in the first call to `partial_fit`. If provided in subsequent calls, it must match the classes given in the first call (see the sketch below).
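+
+Below is a minimal sketch of looping over `partial_fit`; the `fit_encrypted` and `parameters_range` constructor arguments come from the broader training documentation, and the data and parameter values are illustrative assumptions.
+
```python
# Sketch: iterative quantized training with partial_fit (illustrative data).
import numpy
from concrete.ml.sklearn import SGDClassifier

X = numpy.random.rand(200, 4)
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)
classes = numpy.unique(y)

# fit_encrypted / parameters_range enable the quantized/FHE training path
model = SGDClassifier(n_bits=8, fit_encrypted=True, parameters_range=(-1.0, 1.0))

for start in range(0, X.shape[0], 40):
    # classes must be passed on the first call and stay identical afterwards
    model.partial_fit(
        X[start:start + 40], y[start:start + 40], fhe="disable", classes=classes
    )
```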
**Raises:**
@@ -311,7 +324,7 @@ This function does one iteration of SGD training. Looping n_times over this func
______________________________________________________________________
-
+
### method `post_processing`
@@ -347,7 +360,7 @@ The justification for the formula in the loss="modified_huber" case is in the ap
______________________________________________________________________
-
+
## class `SGDRegressor`
@@ -361,7 +374,7 @@ An FHE linear regression model fitted with stochastic gradient descent.
For more details on SGDRegressor please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDRegressor.html
-
+
### method `__init__`
@@ -436,7 +449,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -446,7 +459,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -456,7 +469,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `ElasticNet`
@@ -470,7 +483,7 @@ An ElasticNet regression model with FHE.
For more details on ElasticNet please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html
-
+
### method `__init__`
@@ -480,6 +493,7 @@ __init__(
alpha=1.0,
l1_ratio=0.5,
fit_intercept=True,
+ normalize='deprecated',
precompute=False,
max_iter=1000,
copy_X=True,
@@ -537,7 +551,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -547,7 +561,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -557,7 +571,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `Lasso`
@@ -571,7 +585,7 @@ A Lasso regression model with FHE.
For more details on Lasso please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html
-
+
### method `__init__`
@@ -580,6 +594,7 @@ __init__(
n_bits=8,
alpha: float = 1.0,
fit_intercept=True,
+ normalize='deprecated',
precompute=False,
copy_X=True,
max_iter=1000,
@@ -637,7 +652,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -647,7 +662,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -657,7 +672,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `Ridge`
@@ -671,7 +686,7 @@ A Ridge regression model with FHE.
For more details on Ridge please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html
-
+
### method `__init__`
@@ -680,6 +695,7 @@ __init__(
n_bits=8,
alpha: float = 1.0,
fit_intercept=True,
+ normalize='deprecated',
copy_X=True,
max_iter=None,
tol=0.001,
@@ -735,7 +751,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -745,7 +761,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -755,7 +771,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `LogisticRegression`
@@ -769,7 +785,7 @@ A logistic regression model with FHE.
For more details on LogisticRegression please refer to the scikit-learn documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
-
+
### method `__init__`
@@ -864,7 +880,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -874,7 +890,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
diff --git a/docs/references/api/concrete.ml.sklearn.rf.md b/docs/references/api/concrete.ml.sklearn.rf.md
index 22580f15e..b22f2d709 100644
--- a/docs/references/api/concrete.ml.sklearn.rf.md
+++ b/docs/references/api/concrete.ml.sklearn.rf.md
@@ -165,7 +165,7 @@ __init__(
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.0,
- max_features='sqrt',
+ max_features=1.0,
max_leaf_nodes=None,
min_impurity_decrease=0.0,
bootstrap=True,
diff --git a/docs/references/api/concrete.ml.sklearn.tree_to_numpy.md b/docs/references/api/concrete.ml.sklearn.tree_to_numpy.md
index 16746f52b..ceb7f576a 100644
--- a/docs/references/api/concrete.ml.sklearn.tree_to_numpy.md
+++ b/docs/references/api/concrete.ml.sklearn.tree_to_numpy.md
@@ -15,19 +15,19 @@ Implements the conversion of a tree model to a numpy function.
______________________________________________________________________
-
+
## function `get_onnx_model`
```python
-get_onnx_model(model: Callable, x: ndarray, framework: str) → ModelProto
+get_onnx_model(model, x: ndarray, framework: str) → ModelProto
```
Create ONNX model with Hummingbird convert method.
**Args:**
-- `model` (Callable): The tree model to convert.
+- `model`: The tree model to convert.
- `x` (numpy.ndarray): Dataset used to trace the tree inference and convert the model to ONNX.
- `framework` (str): The framework from which the ONNX model is generated.
  (options: 'xgboost', 'sklearn')
@@ -38,7 +38,7 @@ Create ONNX model with Hummingbird convert method.
______________________________________________________________________
-
+
## function `workaround_squeeze_node_xgboost`
@@ -56,7 +56,7 @@ FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2778 The squeeze o
______________________________________________________________________
-
+
## function `assert_add_node_and_constant_in_xgboost_regressor_graph`
@@ -72,7 +72,7 @@ Assert if an Add node with a specific constant exists in the ONNX graph.
______________________________________________________________________
-
+
## function `add_transpose_after_last_node`
@@ -92,7 +92,7 @@ Add transpose after last node.
______________________________________________________________________
-
+
## function `preprocess_tree_predictions`
@@ -116,7 +116,7 @@ Apply post-processing from the graph.
______________________________________________________________________
-
+
## function `tree_onnx_graph_preprocessing`
@@ -141,7 +141,7 @@ Apply pre-processing onto the ONNX graph.
______________________________________________________________________
-
+
## function `tree_values_preprocessing`
@@ -168,13 +168,13 @@ Pre-process tree values.
______________________________________________________________________
-
+
## function `tree_to_numpy`
```python
tree_to_numpy(
- model: Callable,
+ model: BaseEstimator,
x: ndarray,
framework: str,
use_rounding: bool = True,
@@ -198,3 +198,40 @@ Convert the tree inference to a numpy functions using Hummingbird.
**Returns:**
- `Tuple[Callable, List[QuantizedArray], onnx.ModelProto]`: A tuple with a function that takes a numpy array and returns a numpy array, QuantizedArray object to quantize and de-quantize the output of the tree, and the ONNX model.
+
+______________________________________________________________________
+
+
+
+## function `onnx_fp32_model_to_quantized_model`
+
+```python
+onnx_fp32_model_to_quantized_model(
+ onnx_model: ModelProto,
+ n_bits: int,
+ framework: str,
+ expected_number_of_outputs: int,
+ n_features: int,
+ model_inputs: Optional[ndarray] = None
+)
+```
+
+Build a FHE-compliant onnx-model using a fitted scikit-learn model.
+
+**Args:**
+
+- `onnx_model` (onnx.ModelProto): The fitted scikit-learn model, exported as a Hummingbird ONNX model, to convert.
+- `n_bits` (int): Number of bits to quantize the model. If an int is passed for n_bits, the value will be used for quantizing inputs and weights. If a dict is passed, then it should contain "op_inputs" and "op_weights" as keys with the corresponding number of quantization bits:
+  \- op_inputs: number of bits to quantize the input values
+  \- op_weights: number of bits to quantize the learned parameters
+- `framework` (str): either sklearn or xgboost
+- `expected_number_of_outputs` (int): expected number of outputs
+- `n_features` (int): number of features as inputs of the model
+- `model_inputs` (Optional\[numpy.ndarray\]): optional dataset to use for quantization
+
+**Returns:**
+
+- `onnx.ModelProto`: The converted ONNX model
+- `Optional[Tuple[int, int]]`: Least significant bits to remove
+- `list[UniformQuantizer]`: Input quantizers
+- `list[UniformQuantizer]`: Output quantizers
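+
+For illustration, a hedged sketch of chaining this helper with `get_onnx_model` (documented above); the estimator, data, bit-width and expected number of outputs are placeholder assumptions.
+
```python
# Sketch: export a fitted tree model to ONNX with Hummingbird, then quantize it
# (illustrative estimator, data and bit-width).
import numpy
from sklearn.tree import DecisionTreeRegressor
from concrete.ml.sklearn.tree_to_numpy import (
    get_onnx_model,
    onnx_fp32_model_to_quantized_model,
)

X = numpy.random.rand(100, 4).astype(numpy.float32)
y = numpy.random.rand(100)
model = DecisionTreeRegressor(max_depth=4).fit(X, y)

onnx_model = get_onnx_model(model, X, framework="sklearn")

# Returns the quantized ONNX model, bits to remove and input/output quantizers
q_model, lsbs_to_remove, input_quantizers, output_quantizers = (
    onnx_fp32_model_to_quantized_model(
        onnx_model,
        n_bits=6,
        framework="sklearn",
        expected_number_of_outputs=1,
        n_features=X.shape[1],
        model_inputs=X,
    )
)
```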
diff --git a/docs/references/api/concrete.ml.sklearn.xgb.md b/docs/references/api/concrete.ml.sklearn.xgb.md
index 540c25101..9723b0d85 100644
--- a/docs/references/api/concrete.ml.sklearn.xgb.md
+++ b/docs/references/api/concrete.ml.sklearn.xgb.md
@@ -8,7 +8,7 @@ Implements XGBoost models.
______________________________________________________________________
-
+
## class `XGBClassifier`
@@ -16,7 +16,7 @@ Implements the XGBoost classifier.
See https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.sklearn for more information about the parameters used.
-
+
### method `__init__`
@@ -44,14 +44,24 @@ __init__(
missing: float = nan,
num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Dict[str, int], str] = None,
- interaction_constraints: Optional[str, List[Tuple[str]]] = None,
+ interaction_constraints: Optional[str, Sequence[Sequence[str]]] = None,
importance_type: Optional[str] = None,
gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
+ use_label_encoder: bool = False,
random_state: Optional[int] = None,
- verbosity: Optional[int] = None
+ verbosity: Optional[int] = None,
+ max_bin: Optional[int] = None,
+ callbacks: Optional[List[TrainingCallback]] = None,
+ early_stopping_rounds: Optional[int] = None,
+ max_leaves: Optional[int] = None,
+ eval_metric: Optional[str, List[str], Callable] = None,
+ max_cat_to_onehot: Optional[int] = None,
+ grow_policy: Optional[str] = None,
+ sampling_method: Optional[str] = None,
+ **kwargs
)
```
@@ -125,7 +135,7 @@ Using this attribute is deprecated.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -135,7 +145,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -145,7 +155,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
## class `XGBRegressor`
@@ -153,7 +163,7 @@ Implements the XGBoost regressor.
See https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.sklearn for more information about the parameters used.
-
+
### method `__init__`
@@ -162,7 +172,7 @@ __init__(
n_bits: Union[int, Dict[str, int]] = 6,
max_depth: Optional[int] = 3,
learning_rate: Optional[float] = None,
- n_estimators: Optional[int] = 20,
+ n_estimators: int = 20,
objective: Optional[str] = 'reg:squarederror',
booster: Optional[str] = None,
tree_method: Optional[str] = None,
@@ -181,14 +191,23 @@ __init__(
missing: float = nan,
num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Dict[str, int], str] = None,
- interaction_constraints: Optional[str, List[Tuple[str]]] = None,
+ interaction_constraints: Optional[str, Sequence[Sequence[str]]] = None,
importance_type: Optional[str] = None,
gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
- random_state: Optional[int] = None,
- verbosity: Optional[int] = None
+ random_state: Optional[RandomState, int] = None,
+ verbosity: Optional[int] = None,
+ eval_metric: Optional[str, List[str], Callable] = None,
+ sampling_method: Optional[str] = None,
+ max_leaves: Optional[int] = None,
+ max_bin: Optional[int] = None,
+ max_cat_to_onehot: Optional[int] = None,
+ grow_policy: Optional[str] = None,
+ callbacks: Optional[List[TrainingCallback]] = None,
+ early_stopping_rounds: Optional[int] = None,
+ **kwargs: Any
)
```
@@ -238,7 +257,7 @@ Is None if the model is not fitted.
______________________________________________________________________
-
+
### method `dump_dict`
@@ -248,7 +267,7 @@ dump_dict() → Dict[str, Any]
______________________________________________________________________
-
+
### method `fit`
@@ -258,7 +277,7 @@ fit(X, y, *args, **kwargs) → Any
______________________________________________________________________
-
+
### classmethod `load_dict`
@@ -268,7 +287,7 @@ load_dict(metadata: Dict)
______________________________________________________________________
-
+
### method `post_processing`
diff --git a/docs/references/api/concrete.ml.torch.compile.md b/docs/references/api/concrete.ml.torch.compile.md
index 3cb25f6d1..41d1d19ef 100644
--- a/docs/references/api/concrete.ml.torch.compile.md
+++ b/docs/references/api/concrete.ml.torch.compile.md
@@ -91,7 +91,7 @@ Take a model in torch or ONNX, turn it to numpy, quantize its inputs / weights /
______________________________________________________________________
-
+
## function `compile_torch_model`
@@ -141,7 +141,7 @@ Take a model in torch, turn it to numpy, quantize its inputs / weights / outputs
______________________________________________________________________
-
+
## function `compile_onnx_model`
@@ -191,7 +191,7 @@ Take a model in torch, turn it to numpy, quantize its inputs / weights / outputs
______________________________________________________________________
-
+
## function `compile_brevitas_qat_model`
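
A short, hedged sketch of how `compile_torch_model` is typically called; the tiny module, the representative input set and the parameter values are illustrative assumptions, not a reference implementation:

```python
import numpy
import torch

from concrete.ml.torch.compile import compile_torch_model

# Illustrative toy model, kept small so plain post-training quantization compiles
class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

# Representative input set used for calibration during quantization
torch_inputset = numpy.random.rand(100, 4).astype(numpy.float32)

# Quantize the model and compile it to an FHE-ready quantized module
quantized_module = compile_torch_model(TinyNet(), torch_inputset, n_bits=6)

# Simulated encrypted inference on one sample
y = quantized_module.forward(torch_inputset[:1], fhe="simulate")
```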
diff --git a/docs/references/api/concrete.ml.torch.hybrid_model.md b/docs/references/api/concrete.ml.torch.hybrid_model.md
index 5d6afe1de..da91e7bb2 100644
--- a/docs/references/api/concrete.ml.torch.hybrid_model.md
+++ b/docs/references/api/concrete.ml.torch.hybrid_model.md
@@ -102,7 +102,7 @@ __init__(
______________________________________________________________________
-
+
### method `forward`
@@ -157,7 +157,7 @@ Set the clients keys.
______________________________________________________________________
-
+
### method `remote_call`
diff --git a/pyproject.toml b/pyproject.toml
index 180352c74..0184e5169 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "concrete-ml"
-version = "1.5.0"
+version = "1.6.0-rc0"
description = "Concrete ML is an open-source set of tools which aims to simplify the use of fully homomorphic encryption (FHE) for data scientists."
license = "BSD-3-Clause-Clear"
authors = [
diff --git a/src/concrete/ml/version.py b/src/concrete/ml/version.py
index 97a575f9e..a9a240dd4 100644
--- a/src/concrete/ml/version.py
+++ b/src/concrete/ml/version.py
@@ -1,4 +1,4 @@
"""File to manage the version of the package."""
# Auto-generated by "make set_version" do not modify
-__version__ = "1.5.0"
+__version__ = "1.6.0-rc0"