Skip to content

Commit

Permalink
chore: change api name from mode_training to mode = "training" + add …
Browse files Browse the repository at this point in the history
…test
  • Loading branch information
jfrery committed May 23, 2024
1 parent abbb9df commit 4d33245
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 140 deletions.
147 changes: 36 additions & 111 deletions docs/advanced_examples/LogisticRegressionTraining.ipynb

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions docs/advanced_examples/fhe_training_sgd/versions.json

This file was deleted.

71 changes: 46 additions & 25 deletions src/concrete/ml/deployment/fhe_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import json
import sys
import zipfile
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Tuple, Union

import numpy

from concrete import fhe
from concrete.ml.quantization.quantized_module import QuantizedModule

from ..common.debugging.custom_assert import assert_true
from ..common.serialization.dumpers import dump
Expand All @@ -25,6 +27,11 @@
from importlib_metadata import version


class Mode(Enum):
INFERENCE = "inference"
TRAINING = "training"


def check_concrete_versions(zip_path: Path):
"""Check that current versions match the ones used in development.
Expand Down Expand Up @@ -105,13 +112,13 @@ def load(self):

def run(
self,
serialized_encrypted_quantized_data: bytes,
serialized_encrypted_quantized_data: Union[bytes, Tuple[bytes, ...]],
serialized_evaluation_keys: bytes,
) -> bytes:
"""Run the model on the server over encrypted data.
Args:
serialized_encrypted_quantized_data (bytes): the encrypted, quantized
serialized_encrypted_quantized_data (Union[bytes, Tuple[bytes, ...]]): the encrypted, quantized
and serialized data
serialized_evaluation_keys (bytes): the serialized evaluation keys
Expand All @@ -120,14 +127,22 @@ def run(
"""
assert_true(self.server is not None, "Model has not been loaded.")

deserialized_encrypted_quantized_data = fhe.Value.deserialize(
serialized_encrypted_quantized_data
if not isinstance(serialized_encrypted_quantized_data, tuple):
serialized_encrypted_quantized_data = (serialized_encrypted_quantized_data,)

deserialized_encrypted_quantized_data = tuple(
fhe.Value.deserialize(data) for data in serialized_encrypted_quantized_data
)
deserialized_evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)
result = self.server.run(
deserialized_encrypted_quantized_data, evaluation_keys=deserialized_evaluation_keys
*deserialized_encrypted_quantized_data, evaluation_keys=deserialized_evaluation_keys
)

serialized_result = (
tuple(res.serialize() for res in result)
if isinstance(result, tuple)
else result.serialize()
)
serialized_result = result.serialize()
return serialized_result


Expand Down Expand Up @@ -178,18 +193,25 @@ def _export_model_to_json(self, is_training: bool = False) -> Path:

return json_path

def save(self, via_mlir: bool = False, training_mode: bool = False):
def save(self, mode: Mode = Mode.INFERENCE, via_mlir: bool = False):
"""Export all needed artifacts for the client and server.
Arguments:
mode (Mode): the mode to save the FHE circuit, either "inference" or "training".
via_mlir (bool): serialize with `via_mlir` option from Concrete-Python.
training_mode (bool): if True, save the training part of the FHE circuit.
Raises:
Exception: path_dir is not empty or training module does not exist
"""

if isinstance(mode, str):
mode_upper = mode.upper()
if mode_upper not in Mode.__members__:
raise ValueError("Mode must be either 'inference' or 'training'")
mode = Mode[mode_upper]

# Get fhe_circuit based on the mode
if training_mode:
if mode == Mode.TRAINING:

# Check that training FHE circuit exists
assert_true(
Expand All @@ -212,7 +234,7 @@ def save(self, via_mlir: bool = False, training_mode: bool = False):
)

# Export the quantizers
json_path = self._export_model_to_json()
json_path = self._export_model_to_json(is_training=(mode == Mode.TRAINING))

# Save the circuit for the server
path_circuit_server = Path(self.path_dir).joinpath("server.zip")
Expand Down Expand Up @@ -285,16 +307,9 @@ def load(self): # pylint: disable=no-value-for-parameter
# Initialize the model
self.model = serialized_processing["model_type"]()

if serialized_processing["is_training"]:
self.model.training_quantized_module.input_quantizers = serialized_processing[
"input_quantizers"
]
self.model.training_quantized_module.output_quantizers = serialized_processing[
"output_quantizers"
]
else:
self.model.input_quantizers = serialized_processing["input_quantizers"]
self.model.output_quantizers = serialized_processing["output_quantizers"]
# Load the quantizers
self.model.input_quantizers = serialized_processing["input_quantizers"]
self.model.output_quantizers = serialized_processing["output_quantizers"]

# Load the `_is_fitted` private attribute for built-in models
if "is_fitted" in serialized_processing:
Expand Down Expand Up @@ -328,23 +343,29 @@ def get_serialized_evaluation_keys(self) -> bytes:

return self.client.evaluation_keys.serialize()

def quantize_encrypt_serialize(self, x: numpy.ndarray) -> bytes:
def quantize_encrypt_serialize(
self, x: Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]
) -> bytes:
"""Quantize, encrypt and serialize the values.
Args:
x (numpy.ndarray): the values to quantize, encrypt and serialize
x (Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]): the values to quantize, encrypt and serialize
Returns:
bytes: the quantized, encrypted and serialized values
"""
# Quantize the values
quantized_x = self.model.quantize_input(x)

# To tuple if not tuple
if not isinstance(quantized_x, tuple):
quantized_x = (quantized_x,)

# Encrypt the values
enc_qx = self.client.encrypt(quantized_x)
enc_qx = self.client.encrypt(*quantized_x)

# Serialize the encrypted values to be sent to the server
serialized_enc_qx = enc_qx.serialize()
serialized_enc_qx = tuple(e.serialize() for e in enc_qx)
return serialized_enc_qx

def deserialize_decrypt(self, serialized_encrypted_quantized_result: bytes) -> numpy.ndarray:
Expand Down
62 changes: 61 additions & 1 deletion tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import zipfile
from pathlib import Path
from shutil import copyfile
from concrete.ml.sklearn.linear_model import SGDClassifier

import numpy
import pytest
from sklearn.exceptions import ConvergenceWarning
from torch import nn

from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer, Mode
from concrete.ml.pytest.torch_models import FCSmall
from concrete.ml.pytest.utils import MODELS_AND_DATASETS, get_model_name, instantiate_model_generic
from concrete.ml.quantization.quantized_module import QuantizedModule
Expand Down Expand Up @@ -350,3 +351,62 @@ def check_input_compression(model, fhe_circuit_compressed, is_torch, **compilati
"Compressed input ciphertext's is not smaller than the uncompressed input ciphertext. Got "
f"{compressed_size} bytes (compressed) and {uncompressed_size} bytes (uncompressed)."
)

@pytest.mark.parametrize("model_class, parameters", [MODELS_AND_DATASETS])
@pytest.mark.parametrize("n_bits", [2])
@pytest.mark.parametrize(
"mode, should_raise, error_message",
[
("invalid_mode", True, "Mode must be either 'inference' or 'training'"),
("INVALID_MODE", True, "Mode must be either 'inference' or 'training'"),
("train", True, "Mode must be either 'inference' or 'training'"),
("", True, "Mode must be either 'inference' or 'training'"),
(None, True, "Mode must be either 'inference' or 'training'"),
("inference", False, None),
("training", True, "Training FHE circuit does not exist."),
(Mode.INFERENCE, False, None),
(Mode.TRAINING, True, "Training FHE circuit does not exist."),
],
)
def test_save_mode(model_class, parameters, n_bits, mode, should_raise, error_message, load_data):
"""Test that the save method handles valid and invalid modes correctly."""

# Only run this test for SGDClassifier
if model_class != SGDClassifier:
pytest.skip(f"Skipping test for model class {model_class}")

# Generate random data
x, y = load_data(model_class, **parameters)

x_train = x[:-1]
y_train = y[:-1]

# Instantiate the model
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
model = instantiate_model_generic(model_class, n_bits=n_bits)

# Set fit_encrypted to False to trigger the specific error
model.fit_encrypted = False

# Fit the model
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
warnings.simplefilter("ignore", category=UserWarning)
model.fit(x_train, y_train)

# Compile
model.compile(X=x_train)

# Create FHEModelDev instance
with tempfile.TemporaryDirectory() as temp_dir:
model_dev = FHEModelDev(path_dir=temp_dir, model=model)

if should_raise:
with pytest.raises(ValueError, match=error_message):
model_dev.save(mode=mode)
else:
try:
model_dev.save(mode=mode)
except ValueError:
pytest.fail("Valid mode raised ValueError unexpectedly")

0 comments on commit 4d33245

Please sign in to comment.