Skip to content

Commit

Permalink
todo:
Browse files Browse the repository at this point in the history
- we need a new post-processing step for training (not related to inference)
- same for pre-processing
- dequantize assumes some reshaping, which is wrong.
  • Loading branch information
jfrery committed May 23, 2024
1 parent 4d33245 commit dd44f35
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 105 deletions.
213 changes: 147 additions & 66 deletions docs/advanced_examples/LogisticRegressionTraining.ipynb

Large diffs are not rendered by default.

54 changes: 30 additions & 24 deletions src/concrete/ml/deployment/fhe_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,27 +123,21 @@ def run(
serialized_evaluation_keys (bytes): the serialized evaluation keys
Returns:
bytes: the result of the model
Union[bytes, Tuple[bytes, ...]]: the result of the model
"""
assert_true(self.server is not None, "Model has not been loaded.")

if not isinstance(serialized_encrypted_quantized_data, tuple):
if isinstance(serialized_encrypted_quantized_data, bytes):
serialized_encrypted_quantized_data = (serialized_encrypted_quantized_data,)

deserialized_encrypted_quantized_data = tuple(
fhe.Value.deserialize(data) for data in serialized_encrypted_quantized_data
)
deserialized_evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)
result = self.server.run(
*deserialized_encrypted_quantized_data, evaluation_keys=deserialized_evaluation_keys
)
deserialized_data = tuple(fhe.Value.deserialize(data) for data in serialized_encrypted_quantized_data)
deserialized_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

result = self.server.run(*deserialized_data, evaluation_keys=deserialized_keys)

serialized_result = (
tuple(res.serialize() for res in result)
if isinstance(result, tuple)
else result.serialize()
)
return serialized_result
if isinstance(result, tuple):
return tuple(res.serialize() for res in result)
return result.serialize()


class FHEModelDev:
Expand Down Expand Up @@ -368,40 +362,48 @@ def quantize_encrypt_serialize(
serialized_enc_qx = tuple(e.serialize() for e in enc_qx)
return serialized_enc_qx

def deserialize_decrypt(self, serialized_encrypted_quantized_result: bytes) -> numpy.ndarray:
def deserialize_decrypt(self, serialized_encrypted_quantized_result: Union[bytes, Tuple[bytes, ...]]) -> numpy.ndarray:
"""Deserialize and decrypt the values.
Args:
serialized_encrypted_quantized_result (bytes): the serialized, encrypted
serialized_encrypted_quantized_result (Union[bytes, Tuple[bytes, ...]]): the serialized, encrypted
and quantized result
Returns:
numpy.ndarray: the decrypted and deserialized values
"""
# Ensure the input is a tuple
if isinstance(serialized_encrypted_quantized_result, bytes):
serialized_encrypted_quantized_result = (serialized_encrypted_quantized_result,)

# Deserialize the encrypted values
deserialized_encrypted_quantized_result = fhe.Value.deserialize(
serialized_encrypted_quantized_result
deserialized_encrypted_quantized_result = tuple(
fhe.Value.deserialize(data) for data in serialized_encrypted_quantized_result
)

# Decrypt the values
deserialized_decrypted_quantized_result = self.client.decrypt(
deserialized_encrypted_quantized_result
*deserialized_encrypted_quantized_result
)
assert isinstance(deserialized_decrypted_quantized_result, numpy.ndarray)
assert isinstance(deserialized_decrypted_quantized_result, (numpy.ndarray, tuple))
return deserialized_decrypted_quantized_result

def deserialize_decrypt_dequantize(
self, serialized_encrypted_quantized_result: bytes
self, serialized_encrypted_quantized_result: Union[bytes, Tuple[bytes, ...]]
) -> numpy.ndarray:
"""Deserialize, decrypt and de-quantize the values.
Args:
serialized_encrypted_quantized_result (bytes): the serialized, encrypted
and quantized result
serialized_encrypted_quantized_result (Union[bytes, Tuple[bytes, ...]]): the
serialized, encrypted and quantized result
Returns:
numpy.ndarray: the decrypted (de-quantized) values
"""
# Ensure the input is a tuple
if isinstance(serialized_encrypted_quantized_result, bytes):
serialized_encrypted_quantized_result = (serialized_encrypted_quantized_result,)

# Decrypt and deserialize the values
deserialized_decrypted_quantized_result = self.deserialize_decrypt(
serialized_encrypted_quantized_result
Expand All @@ -417,4 +419,8 @@ def deserialize_decrypt_dequantize(
deserialized_decrypted_dequantized_result
)

# If the result is a single element tuple, return the element itself
if len(deserialized_decrypted_dequantized_result) == 1:
return deserialized_decrypted_dequantized_result[0]

return deserialized_decrypted_dequantized_result
30 changes: 19 additions & 11 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Type, Union
from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, Type, Union

import brevitas.nn as qnn
import numpy
Expand Down Expand Up @@ -1704,22 +1704,30 @@ def _quantize_model(self, X):

return self

def quantize_input(self, X: numpy.ndarray) -> numpy.ndarray:
def quantize_input(
self, X: Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]
) -> Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]:
self.check_model_is_fitted()
q_X = self.input_quantizers[0].quant(X)

assert q_X.dtype == numpy.int64, "Inputs were not quantized to int64 values"
if isinstance(X, tuple):
q_X = tuple(self.input_quantizers[i].quant(x) for i, x in enumerate(X))
for q in q_X:
assert q.dtype == numpy.int64, "Inputs were not quantized to int64 values"
else:
q_X = self.input_quantizers[0].quant(X)
assert q_X.dtype == numpy.int64, "Inputs were not quantized to int64 values"

return q_X

def dequantize_output(self, q_y_preds: numpy.ndarray) -> numpy.ndarray:
def dequantize_output(self, q_y_preds: Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]) -> Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]:
self.check_model_is_fitted()

# De-quantize the output values
y_preds = self.output_quantizers[0].dequant(q_y_preds)

# If the preds have shape (n, 1), squeeze it to shape (n,) like in scikit-learn
if y_preds.ndim == 2 and y_preds.shape[1] == 1:
return y_preds.ravel()
if isinstance(q_y_preds, tuple):
y_preds = tuple(self.output_quantizers[i].dequant(q) for i, q in enumerate(q_y_preds))
else:
y_preds = self.output_quantizers[0].dequant(q_y_preds)
if y_preds.ndim == 2 and y_preds.shape[1] == 1:
y_preds = y_preds.ravel()

return y_preds

Expand Down
14 changes: 10 additions & 4 deletions tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,22 @@
import zipfile
from pathlib import Path
from shutil import copyfile
from concrete.ml.sklearn.linear_model import SGDClassifier

import numpy
import pytest
from sklearn.exceptions import ConvergenceWarning
from torch import nn

from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer, Mode
from concrete.ml.deployment.fhe_client_server import (
FHEModelClient,
FHEModelDev,
FHEModelServer,
Mode,
)
from concrete.ml.pytest.torch_models import FCSmall
from concrete.ml.pytest.utils import MODELS_AND_DATASETS, get_model_name, instantiate_model_generic
from concrete.ml.quantization.quantized_module import QuantizedModule
from concrete.ml.sklearn.linear_model import SGDClassifier
from concrete.ml.torch.compile import compile_torch_model

# pylint: disable=too-many-statements,too-many-locals
Expand Down Expand Up @@ -352,6 +357,7 @@ def check_input_compression(model, fhe_circuit_compressed, is_torch, **compilati
f"{compressed_size} bytes (compressed) and {uncompressed_size} bytes (uncompressed)."
)


@pytest.mark.parametrize("model_class, parameters", [MODELS_AND_DATASETS])
@pytest.mark.parametrize("n_bits", [2])
@pytest.mark.parametrize(
Expand All @@ -370,7 +376,7 @@ def check_input_compression(model, fhe_circuit_compressed, is_torch, **compilati
)
def test_save_mode(model_class, parameters, n_bits, mode, should_raise, error_message, load_data):
"""Test that the save method handles valid and invalid modes correctly."""

# Only run this test for SGDClassifier
if model_class != SGDClassifier:
pytest.skip(f"Skipping test for model class {model_class}")
Expand Down Expand Up @@ -409,4 +415,4 @@ def test_save_mode(model_class, parameters, n_bits, mode, should_raise, error_me
try:
model_dev.save(mode=mode)
except ValueError:
pytest.fail("Valid mode raised ValueError unexpectedly")
pytest.fail("Valid mode raised ValueError unexpectedly")

0 comments on commit dd44f35

Please sign in to comment.