Skip to content

Commit

Permalink
chore: enable input compression through os.environ
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Apr 23, 2024
1 parent f7b5122 commit 2cb854f
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 11 deletions.
5 changes: 3 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import hashlib
import json
import os
import random
import re
from pathlib import Path
Expand Down Expand Up @@ -153,7 +154,7 @@ def default_configuration():
insecure_key_cache_location="ConcretePythonKeyCache",
fhe_simulation=False,
fhe_execution=True,
compress_input_ciphertexts=True,
compress_input_ciphertexts=os.environ.get("USE_INPUT_COMPRESSION", "1") == "1",
)


Expand All @@ -171,7 +172,7 @@ def simulation_configuration():
insecure_key_cache_location="ConcretePythonKeyCache",
fhe_simulation=True,
fhe_execution=False,
compress_input_ciphertexts=True,
compress_input_ciphertexts=os.environ.get("USE_INPUT_COMPRESSION", "1") == "1",
)


Expand Down
6 changes: 6 additions & 0 deletions src/concrete/ml/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utils that can be re-used by other pieces of code in the module."""

import enum
import os
import string
from functools import partial
from types import FunctionType
Expand Down Expand Up @@ -47,6 +48,11 @@
# should be exact compared to their Concrete ML QuantizedModule
QUANT_ROUND_LIKE_ROUND_PBS = False

# Enable input ciphertext compression
# Note: This setting is fixed and cannot be altered by users
# However, for internal testing purposes, we retain the capability to disable this feature
os.environ["USE_INPUT_COMPRESSION"] = os.environ.get("USE_INPUT_COMPRESSION", "1")


class FheMode(str, enum.Enum):
"""Enum representing the execution mode.
Expand Down
6 changes: 5 additions & 1 deletion src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""QuantizedModule API."""

import copy
import os
import re
from functools import partial
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, TextIO, Tuple, Union
Expand Down Expand Up @@ -737,6 +738,9 @@ def compile(
# Find the right way to set parameters for compiler, depending on the way we want to default
p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

# Enable input ciphertext compression
enable_input_compression = os.environ.get("USE_INPUT_COMPRESSION", "1") == "1"

self.fhe_circuit = compiler.compile(
inputset,
configuration=configuration,
Expand All @@ -748,7 +752,7 @@ def compile(
single_precision=False,
fhe_simulation=False,
fhe_execution=True,
compress_input_ciphertexts=True,
compress_input_ciphertexts=enable_input_compression,
)

self._is_compiled = True
Expand Down
5 changes: 4 additions & 1 deletion src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ def compile(
f"{type(module_to_compile)}."
)

# Enable input ciphertext compression
enable_input_compression = os.environ.get("USE_INPUT_COMPRESSION", "1") == "1"

self.fhe_circuit_ = module_to_compile.compile(
inputset,
configuration=configuration,
Expand All @@ -571,7 +574,7 @@ def compile(
single_precision=False,
fhe_simulation=False,
fhe_execution=True,
compress_input_ciphertexts=True,
compress_input_ciphertexts=enable_input_compression,
)

self._is_compiled = True
Expand Down
56 changes: 50 additions & 6 deletions tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests the deployment APIs."""

import json
import os
import tempfile
import warnings
import zipfile
Expand Down Expand Up @@ -107,8 +108,15 @@ def test_client_server_sklearn(
x_test, model, key_dir, check_array_equal, check_float_array_equal
)

compilation_kwargs = {
"X": x_train,
"configuration": default_configuration,
}

# Compile the model
fhe_circuit = model.compile(x_train, configuration=default_configuration)
fhe_circuit = model.compile(**compilation_kwargs)

check_input_compression(model, fhe_circuit, is_torch=False, **compilation_kwargs)

# Check that client and server files are properly generated
check_client_server_files(model)
Expand Down Expand Up @@ -150,12 +158,17 @@ def test_client_server_custom_model(

torch_model = FCSmall(2, nn.ReLU)

compilation_kwargs = {
"torch_inputset": x_train,
"configuration": default_configuration,
"n_bits": 2,
}

# Get the quantized module from the model and compile it
quantized_numpy_module = compile_torch_model(
torch_model,
x_train,
configuration=default_configuration,
n_bits=2,
quantized_numpy_module = compile_torch_model(torch_model, **compilation_kwargs)

check_input_compression(
torch_model, quantized_numpy_module.fhe_circuit, is_torch=True, **compilation_kwargs
)

# Check that client and server files are properly generated
Expand Down Expand Up @@ -288,3 +301,34 @@ def check_client_server_execution(

# Clean up
disk_network.cleanup()


def check_input_compression(model, fhe_circuit_compressed, is_torch, **compilation_kwargs):
"""Check that input compression properly reduces input sizes."""

# Check that input ciphertext compression is enabled
assert os.environ.get("USE_INPUT_COMPRESSION") == "1", "'USE_INPUT_COMPRESSION' is not enabled"

compressed_size = fhe_circuit_compressed.size_of_inputs

with pytest.MonkeyPatch.context() as mp_context:

# Disable input ciphertext compression
mp_context.setenv("USE_INPUT_COMPRESSION", "0")

# Check that input ciphertext compression is disabled
assert (
os.environ.get("USE_INPUT_COMPRESSION") == "0"
), "'USE_INPUT_COMPRESSION' is not disabled"

if is_torch:
fhe_circuit_uncompressed = compile_torch_model(model, **compilation_kwargs).fhe_circuit
else:
fhe_circuit_uncompressed = model.compile(**compilation_kwargs)

uncompressed_size = fhe_circuit_uncompressed.size_of_inputs

assert compressed_size < uncompressed_size, (
"Compressed input ciphertext's is not smaller than the uncompressed input ciphertext. Got "
f"{compressed_size} bytes (compressed) and {uncompressed_size} bytes (uncompressed)."
)
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def wrapper(*args, **kwargs):
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location=KEYGEN_CACHE_DIR,
compress_input_ciphertexts=True,
)

print("Compiling the model.")
Expand Down

0 comments on commit 2cb854f

Please sign in to comment.