chore: add requant in FHE

RomanBredehoft committed May 30, 2024
1 parent bedb72c commit 489fdc6
Showing 7 changed files with 92 additions and 60 deletions.
40 changes: 20 additions & 20 deletions docs/advanced_examples/LogisticRegressionTraining.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/concrete/ml/common/utils.py
@@ -51,7 +51,8 @@
# Enable input ciphertext compression
# Note: This setting is fixed and cannot be altered by users
# However, for internal testing purposes, we retain the capability to disable this feature
os.environ["USE_INPUT_COMPRESSION"] = os.environ.get("USE_INPUT_COMPRESSION", "1")
# TODO: remove this once the nightly with the fix is integrated
os.environ["USE_INPUT_COMPRESSION"] = os.environ.get("USE_INPUT_COMPRESSION", "0")


class FheMode(str, enum.Enum):
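
As a quick illustration of how this toggle works: since 'os.environ.get' only supplies a default, setting the variable before Concrete ML is first imported overrides the value chosen here. A minimal sketch for internal testing, under the assumption that the flag is applied once, when this module is imported:

import os

# Internal testing only: force input ciphertext compression back on. Users are not
# expected to touch this; the committed default above stays "0" until the nightly fix lands.
os.environ["USE_INPUT_COMPRESSION"] = "1"

from concrete.ml.common import utils  # noqa: E402  (imported after setting the variable)
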
11 changes: 7 additions & 4 deletions src/concrete/ml/quantization/post_training.py
@@ -681,15 +681,17 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray):
node_results[output_name] = node_output[0]
constants.add(output_name)

def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
def quantize_module(
self, *calibration_data: numpy.ndarray, composition_mapping: Optional[Dict] = None
) -> QuantizedModule:
"""Quantize numpy module.
Following https://arxiv.org/abs/1712.05877 guidelines.
Args:
*calibration_data (numpy.ndarray): Data that will be used to compute the bounds,
scales and zero point values for every quantized
object.
calibration_data (numpy.ndarray): Data that will be used to compute the bounds,
scales and zero point values for every quantized object.
composition_mapping (Optional[Dict]): Dictionary mapping output positions to input positions when the module is composed with itself, used to requantize outputs with the matching input quantizers. Default to None.
Returns:
QuantizedModule: Quantized numpy module
@@ -709,6 +711,7 @@ def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
),
quant_layers_dict=self.quant_ops_dict,
onnx_model=self.numpy_model.onnx_model,
composition_mapping=composition_mapping,
)

adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
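
A hedged sketch of the new keyword in use. The 'post_training_quant' object and the calibration arrays below are placeholders (only the 'composition_mapping' argument comes from this diff); the mapping mirrors the training use case introduced later in this commit:

import numpy

# Placeholder calibration data for a circuit with four inputs: X, y, weights, bias.
calibration_data = (
    numpy.random.rand(16, 10),
    numpy.random.rand(16, 1),
    numpy.random.rand(1, 10, 1),
    numpy.random.rand(1, 1, 1),
)

# 'post_training_quant' stands for an already-built post-training quantization object
# from this module; outputs 0 and 1 will be requantized with the quantizers of
# inputs 2 and 3 respectively.
quantized_module = post_training_quant.quantize_module(
    *calibration_data,
    composition_mapping={0: 2, 1: 3},
)
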
24 changes: 19 additions & 5 deletions src/concrete/ml/quantization/quantized_module.py
@@ -97,6 +97,7 @@ def __init__(
ordered_module_output_names: Optional[Iterable[str]] = None,
quant_layers_dict: Optional[Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]] = None,
onnx_model: Optional[onnx.ModelProto] = None,
composition_mapping: Optional[Dict] = None,
):

all_or_none_params = [
@@ -139,6 +140,9 @@ def __init__(
else:
self.output_quantizers = []

# TODO: add check for inputs and outputs
self._composition_mapping = composition_mapping

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
def set_reduce_sum_copy(self):
"""Set reduce sum to copy or not the inputs.
@@ -272,7 +276,7 @@ def _set_output_quantizers(self) -> List[UniformQuantizer]:
Returns:
List[UniformQuantizer]: List of output quantizers.
"""
output_layers = (
output_layers = list(
self.quant_layers_dict[output_name][1]
for output_name in self.ordered_module_output_names
)
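
The switch from a bare generator expression to 'list(...)' presumably lets 'output_layers' be consumed more than once (or indexed); a generator is exhausted after a single pass, as this self-contained sketch shows:

# A bare generator expression yields its items exactly once.
quant_layers_dict = {"out_a": (("x",), "QUANT_OP_A"), "out_b": (("x",), "QUANT_OP_B")}
output_names = ["out_a", "out_b"]

as_generator = (quant_layers_dict[name][1] for name in output_names)
print(list(as_generator))  # ['QUANT_OP_A', 'QUANT_OP_B']
print(list(as_generator))  # [] -- already consumed

as_list = list(quant_layers_dict[name][1] for name in output_names)
print(as_list)  # ['QUANT_OP_A', 'QUANT_OP_B'], reusable any number of times
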
@@ -483,12 +487,22 @@ def _clear_forward(
# The output of a graph must be a QuantizedArray
assert all(isinstance(elt, QuantizedArray) for elt in output_quantized_arrays)

results = tuple(
q_results = tuple(
elt.qvalues for elt in output_quantized_arrays if isinstance(elt, QuantizedArray)
)
if len(results) == 1:
return results[0]
return results

if self._composition_mapping is not None:
q_results = tuple(
self.input_quantizers[input_i].quant(
self.output_quantizers[output_i].dequant(q_results[output_i])
)
for output_i, input_i in self._composition_mapping.items()
)

if len(q_results) == 1:
return q_results[0]

return q_results

def _fhe_forward(
self, *q_x: numpy.ndarray, simulate: bool = True
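
In words, the requant step added to '_clear_forward' takes each output listed in the mapping, dequantizes it with its own output quantizer, then requantizes it with the quantizer of the input it will feed on the next composed run, so the looped values come back in the scale the circuit expects. A standalone sketch (the quantizer objects are placeholders assumed to expose 'quant()'/'dequant()' like the UniformQuantizer instances held by the module):

composition_mapping = {0: 2, 1: 3}  # output position -> input position it feeds back into

def requantize_composed_outputs(q_results, input_quantizers, output_quantizers, mapping):
    """Bring each mapped output back to the scale of the input it will feed."""
    return tuple(
        input_quantizers[input_i].quant(
            output_quantizers[output_i].dequant(q_results[output_i])
        )
        for output_i, input_i in mapping.items()
    )
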
29 changes: 22 additions & 7 deletions src/concrete/ml/sklearn/linear_model.py
@@ -358,9 +358,10 @@ def _get_training_quantized_module(
# Enable the underlying FHE circuit to be composed with itself
# This feature is used in order to be able to iterate in the clear n times without having
# to encrypt/decrypt the weight/bias values between each loop
# configuration = Configuration(composable=True, detect_overflow_in_simulation=False)
configuration = Configuration(composable=True)

composition_mapping = {0: 2, 1: 3}

# Compile the model using the compile set
if self.verbose:
print("Compiling training circuit ...")
@@ -374,6 +375,7 @@
p_error=self.training_p_error,
configuration=configuration,
reduce_sum_copy=True,
composition_mapping=composition_mapping,
)
end = time.time()

@@ -520,7 +522,7 @@ def _fit_encrypted(

# Initialize the weight values with the given ones if some are provided
if coef_init is not None:
weights = coef_init
weights = coef_init.reshape(weight_shape)

# Else, if warm start is activated or this is a partial fit, use some already computed
# weight values if there are some
@@ -540,7 +542,7 @@

# Initialize the bias values with the given ones if some are provided
if intercept_init is not None:
bias = intercept_init
bias = intercept_init.reshape(bias_shape)

# Else, if warm start is activated or this is a partial fit, use some already computed
# bias values if there are some
@@ -627,15 +629,28 @@ def _fit_encrypted(
X_batches_enc[iteration_step],
y_batches_enc[iteration_step],
)

# Train the model over one iteration
inference_start = time.time()

# If the training is done in FHE, execute the underlying FHE circuit directly on the
# encrypted values
weights_enc, bias_enc = self.training_quantized_module.quantized_forward(
X_batch_enc_i, y_batch_enc_i, weights_enc, bias_enc, fhe=fhe
)
if fhe == "execute":
weights_enc, bias_enc = self.training_quantized_module.fhe_circuit.run(
X_batch_enc_i,
y_batch_enc_i,
weights_enc,
bias_enc,
)

# Else, use the quantized module on the quantized values (works for both quantized
# clear and FHE simulation modes). It is important to note that 'quantized_forward'
# with 'fhe="execute"' is executing Concrete's 'encrypt_run_decrypt' method, as opposed
# to the 'run' method right above. We thus need to separate these cases since values
# are already encrypted here.
else:
weights_enc, bias_enc = self.training_quantized_module.quantized_forward(
X_batch_enc_i, y_batch_enc_i, weights_enc, bias_enc, fhe=fhe
)

if self.verbose:
print(
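
The reason for the new branch, in terms of Concrete's circuit API (a hedged sketch: 'circuit' stands for the training module's 'fhe_circuit' and the argument names are placeholders):

# fhe == "execute": batches, weights and bias are already ciphertexts, so the circuit
# is run directly and its outputs stay encrypted across iterations.
weights_enc, bias_enc = circuit.run(x_enc, y_enc, weights_enc, bias_enc)

# Other modes go through 'quantized_forward'; with fhe="execute" that method would call
# 'encrypt_run_decrypt', which expects clear quantized inputs and decrypts the result,
# hence the explicit split above.
weights_q, bias_q = circuit.encrypt_run_decrypt(x_q, y_q, weights_q, bias_q)
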
21 changes: 19 additions & 2 deletions src/concrete/ml/torch/compile.py
@@ -75,6 +75,7 @@ def build_quantized_module(
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
reduce_sum_copy=False,
composition_mapping: Optional[Dict] = None,
) -> QuantizedModule:
"""Build a quantized module from a Torch or ONNX model.
@@ -124,10 +125,14 @@
# FIXME: mismatch here. We traced with dummy_input_for_tracing which made some operator
# only work over shape of (1, ., .). For example, some reshape have newshape hardcoded based
# on the inputset we sent in the NumpyModule.
quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)
quantized_module = post_training_quant.quantize_module(
*inputset_as_numpy_tuple, composition_mapping=composition_mapping
)

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
if reduce_sum_copy:
quantized_module.set_reduce_sum_copy()

return quantized_module


@@ -145,7 +150,8 @@ def _compile_torch_or_onnx_model(
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
reduce_sum_copy=False,
reduce_sum_copy: bool = False,
composition_mapping: Optional[Dict] = None,
) -> QuantizedModule:
"""Compile a torch module or ONNX into an FHE equivalent.
@@ -191,6 +197,12 @@
convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
)

if composition_mapping is not None and not configuration.composable:
raise ValueError(
"Please enable the composition feature in order to be able to take the mapping between "
"inputs and output quantizers into account."
)

# Build the quantized module
quantized_module = build_quantized_module(
model=model,
@@ -199,6 +211,7 @@
n_bits=n_bits,
rounding_threshold_bits=rounding_threshold_bits,
reduce_sum_copy=reduce_sum_copy,
composition_mapping=composition_mapping,
)

# Check that p_error or global_p_error is not set in both the configuration and in the direct
@@ -248,6 +261,7 @@
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
reduce_sum_copy: bool = False,
composition_mapping: Optional[Dict] = None,
) -> QuantizedModule:
"""Compile a torch module into an FHE equivalent.
@@ -314,9 +328,11 @@
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
reduce_sum_copy=reduce_sum_copy,
composition_mapping=composition_mapping,
)


# TODO: add 'composition_mapping' here as well
# pylint: disable-next=too-many-arguments
def compile_onnx_model(
onnx_model: onnx.ModelProto,
@@ -397,6 +413,7 @@
)


# TODO: add 'composition_mapping' here as well ?
# pylint: disable-next=too-many-arguments
def compile_brevitas_qat_model(
torch_model: torch.nn.Module,
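
Putting the pieces together, a hedged sketch of compiling with the new keyword. The model and calibration arrays are placeholders ('training_step_module' is assumed to be a torch.nn.Module whose outputs 0 and 1 loop back into its inputs 2 and 3); 'Configuration(composable=True)' matches the training code above, and composition must be enabled or the new check raises a ValueError:

import numpy
from concrete.fhe import Configuration
from concrete.ml.torch.compile import compile_torch_model

# Placeholder calibration set: X, y, weights, bias.
inputset = (
    numpy.random.rand(16, 10).astype(numpy.float32),
    numpy.random.rand(16, 1).astype(numpy.float32),
    numpy.random.rand(1, 10, 1).astype(numpy.float32),
    numpy.random.rand(1, 1, 1).astype(numpy.float32),
)

quantized_module = compile_torch_model(
    training_step_module,  # hypothetical iterative model
    inputset,
    n_bits=7,
    configuration=Configuration(composable=True),
    composition_mapping={0: 2, 1: 3},
)
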
24 changes: 3 additions & 21 deletions tests/sklearn/test_fhe_training.py
@@ -530,24 +530,6 @@ def test_encrypted_fit_coherence(
assert array_allclose_and_same_shape(y_pred_proba_simulated, y_pred_proba_disable)
assert array_allclose_and_same_shape(y_pred_class_simulated, y_pred_class_disable)

# Define early break parameters, with a very high tolerance
early_break_kwargs = {"early_stopping": True, "tol": 1e100}

# We don't have any way to properly test early break, we therefore disable the accuracy check
# in order to avoid flaky issues
check_encrypted_fit(
x,
y,
n_bits,
random_state,
parameters_range,
max_iter,
fit_intercept,
check_accuracy=None,
fhe="simulate",
init_kwargs=early_break_kwargs,
)

weights_partial, bias_partial, y_pred_proba_partial, y_pred_class_partial, _ = (
check_encrypted_fit(
x,
@@ -594,7 +576,7 @@

# Fit the model for max_iter // 2 iterations and retrieved the weight/bias values, as well as
# the RNG object
weights_coef_init, bias_coef_init, _, _, rng_coef_init = check_encrypted_fit(
weights_coef_init_partial, bias_coef_init_partial, _, _, rng_coef_init = check_encrypted_fit(
x,
y,
n_bits,
@@ -610,8 +592,8 @@

# Define coef parameters
coef_init_fit_kwargs = {
"coef_init": weights_coef_init,
"intercept_init": bias_coef_init,
"coef_init": weights_coef_init_partial,
"intercept_init": bias_coef_init_partial,
}

# Fit the model for the remaining iterations starting at the previous weight/bias values. It is

0 comments on commit 489fdc6