Skip to content

Commit

Permalink
chore: address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Sep 20, 2023
1 parent 0b6ef48 commit c964cf2
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 157 deletions.
51 changes: 1 addition & 50 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,53 +183,6 @@ def get_equivalent_numpy_forward_from_onnx(
# Optimize ONNX graph
# List of all currently supported onnx optimizer passes
# From https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
# onnx_passes = [
# 'adjust_add',
# 'rename_input_output',
# 'set_unique_name_for_nodes',
# 'nop',
# 'eliminate_nop_cast',
# 'eliminate_nop_dropout',
# 'eliminate_nop_flatten',
# 'extract_constant_to_initializer',
# 'eliminate_if_with_const_cond',
# 'eliminate_nop_monotone_argmax',
# 'eliminate_nop_pad',
# 'eliminate_nop_concat',
# 'eliminate_nop_split',
# 'eliminate_nop_expand',
# 'eliminate_shape_gather',
# 'eliminate_slice_after_shape',
# 'eliminate_nop_transpose',
# 'fuse_add_bias_into_conv',
# 'fuse_bn_into_conv',
# 'fuse_consecutive_concats',
# 'fuse_consecutive_log_softmax',
# 'fuse_consecutive_reduce_unsqueeze',
# 'fuse_consecutive_squeezes',
# 'fuse_consecutive_transposes',
# 'fuse_matmul_add_bias_into_gemm',
# 'fuse_pad_into_conv',
# 'fuse_pad_into_pool',
# 'fuse_transpose_into_gemm',
# 'replace_einsum_with_matmul',
# 'lift_lexical_references',
# 'split_init',
# 'split_predict',
# 'fuse_concat_into_reshape',
# 'eliminate_nop_reshape',
# 'eliminate_nop_with_unit',
# 'eliminate_common_subexpression',
# 'fuse_qkv',
# 'fuse_consecutive_unsqueezes',
# 'eliminate_deadend',
# 'eliminate_identity',
# 'eliminate_shape_op',
# 'fuse_consecutive_slices',
# 'eliminate_unused_initializer',
# 'eliminate_duplicate_initializer',
# 'adjust_slice_and_matmul'
# ]
onnx_passes = [
"fuse_matmul_add_bias_into_gemm",
"eliminate_nop_pad",
Expand All @@ -243,9 +196,7 @@ def get_equivalent_numpy_forward_from_onnx(
# ONNX optimizer does not optimize Mat-Mult + Bias pattern into GEMM if the input isn't a matrix
# We manually do the optimization for this case
equivalent_onnx_model = fuse_matmul_bias_to_gemm(equivalent_onnx_model)
with open("debug.onnx", "wb") as file:
file.write(equivalent_onnx_model.SerializeToString())
# checker.check_model(equivalent_onnx_model)
checker.check_model(equivalent_onnx_model)

# Check supported operators
required_onnx_operators = set(get_op_type(node) for node in equivalent_onnx_model.graph.node)
Expand Down
4 changes: 1 addition & 3 deletions src/concrete/ml/quantization/post_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,9 +788,7 @@ def _process_initializer(self, n_bits: int, values: numpy.ndarray):
QuantizedArray: a quantized tensor with integer values on n_bits bits
"""

if isinstance(values, numpy.ndarray) and numpy.issubdtype(
values.dtype, numpy.integer
): # pragma:no cover
if isinstance(values, numpy.ndarray) and numpy.issubdtype(values.dtype, numpy.integer):
return values.view(RawOpOutput)

assert isinstance(values, (numpy.ndarray, float))
Expand Down
5 changes: 4 additions & 1 deletion src/concrete/ml/torch/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def build_quantized_module(
convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
)

# No batch dimension (i.e., 0 instead of [0]) because else GEMM onnx pass can't be applied
# Tracing needs to be done with the batch size of 1 since we compile our models to FHE with
# this batch size. The input set contains many examples, to determine a representative
# bit-width, but for tracing we only take a single one. We need the ONNX tracing batch size to
# match the batch size during FHE inference which can only be 1 for the moment.
dummy_input_for_tracing = tuple(
torch.from_numpy(val[[0], ::]).float() for val in inputset_as_numpy_tuple
)
Expand Down
75 changes: 57 additions & 18 deletions src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from torch import nn
from transformers import Conv1D

from ..common.utils import MAX_BITWIDTH_BACKWARD_COMPATIBLE
from ..deployment.fhe_client_server import FHEModelClient, FHEModelDev
from .compile import QuantizedModule, compile_torch_model


class FHEMode(enum.Enum):
class HybridFHEMode(enum.Enum):
"""Simple enum for different modes of execution of HybridModel."""

DISABLE = "disable" # Use torch weights
Expand Down Expand Up @@ -99,8 +100,7 @@ def __init__(
self.calibration_data: List = []
self.uid = str(uuid.uuid4())
self.private_q_module: Optional[QuantizedModule] = None
# TODO: figure out if this is good
self.fhe_local_mode: FHEMode = FHEMode.CALIBRATE
self.fhe_local_mode: HybridFHEMode = HybridFHEMode.CALIBRATE
self.clients: Dict[str, Tuple[str, FHEModelClient]] = {}
self.path_to_keys: Optional[Path] = None
self.path_to_clients: Optional[Path] = None
Expand All @@ -120,6 +120,7 @@ def init_fhe_client(
Raises:
ValueError: if anything goes wrong with the server.
"""
# Handle paths
self.path_to_clients = path_to_client
if self.path_to_clients is None:
self.path_to_clients = Path() / "clients"
Expand All @@ -129,6 +130,8 @@ def init_fhe_client(
self.path_to_keys = Path() / "keys"
self.path_to_keys.mkdir(exist_ok=True)

# List all shapes supported by the server
# This is needed until we have generic shape support in Concrete Python
assert self.module_name is not None
shapes_response = requests.get(
f"{self.server_remote_address}/list_shapes",
Expand All @@ -139,6 +142,8 @@ def init_fhe_client(
raise ValueError(
f"Couldn't get shapes from server:\n{shapes_response.content.decode('utf-8')}"
)

# For all supported shape we need to get the FHE client from the server
shapes = shapes_response.json()
for shape in shapes:
client_response = requests.get(
Expand Down Expand Up @@ -170,6 +175,7 @@ def init_fhe_client(
print(f"Evaluation keys size: {len(serialized_evaluation_keys) / (10**6):.2f} MB")
assert isinstance(serialized_evaluation_keys, bytes)
assert self.module_name is not None
# Upload the key to the server
response = requests.post(
f"{self.server_remote_address}/add_key",
data={
Expand All @@ -181,46 +187,61 @@ def init_fhe_client(
)
assert response.status_code == 200, response.content.decode("utf-8")
uid = response.json()["uid"]
# We store the key id and the client in the object
# If we observe memory issues due to this we can always move
# towards client lazy loading with caching as done on the server.
self.clients[shape] = (uid, client)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the remote module.
To change the behavior of this forward function one must change the fhe_local_mode
attribute. Choices are:
- disable: forward using torch module
- remote: forward with fhe client-server
- simulate: forward with local fhe simulation
- calibrate: forward for calibration
Args:
x (torch.Tensor): The input tensor.
Returns:
(torch.Tensor): The output tensor.
Raises:
ValueError: if fhe_mode is not supported
ValueError: if local_fhe_mode is not supported
"""
# - disable: torch module
# - remote: client-server
# - simulate: compiled simulation
# - calibrate: calibration

if self.fhe_local_mode not in {FHEMode.DISABLE, FHEMode.CALIBRATE, FHEMode.REMOTE, None}:
if self.fhe_local_mode not in {
HybridFHEMode.DISABLE,
HybridFHEMode.CALIBRATE,
HybridFHEMode.REMOTE,
None,
}:
# Using quantized module
assert self.private_q_module is not None
y = torch.Tensor(
self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
)
elif self.fhe_local_mode == FHEMode.DISABLE:
elif self.fhe_local_mode == HybridFHEMode.DISABLE:
# Calling torch
assert self.private_module is not None
y = self.private_module.forward(
x.detach(),
)
assert isinstance(y, torch.Tensor)
elif self.fhe_local_mode == FHEMode.CALIBRATE:
elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
# Calling torch + gathering calibration data
assert self.private_module is not None
self.calibration_data.append(x.detach())
y = self.private_module(x)
assert isinstance(y, torch.Tensor)
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3869
elif self.fhe_local_mode == FHEMode.REMOTE: # pragma:no cover
elif self.fhe_local_mode == HybridFHEMode.REMOTE: # pragma:no cover
# Remote call
y = self.remote_call(x)
else: # pragma:no cover
Expand All @@ -237,14 +258,17 @@ def remote_call(self, x: torch.Tensor) -> torch.Tensor: # pragma:no cover
Returns:
torch.Tensor: The result of the FHE computation
"""
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3869
# implement server call and client initialization
# Store tensor device and move to CPU for FHE encryption
base_device = x.device
x = x.to(device="cpu")

# We need to iterate over elements in the batch since
# we don't support batch inference
inferences = []
for index in range(len(x)):
# Manage tensor, tensor shape, and encrypt tensor
clear_input = x[[index], :].detach().numpy()
input_shape = tuple(clear_input.shape)
input_shape = (1,) + tuple(clear_input.shape)
repr_input_shape = str(input_shape[1:])
assert isinstance(clear_input, numpy.ndarray)
assert repr_input_shape in self.clients
Expand All @@ -260,6 +284,7 @@ def remote_call(self, x: torch.Tensor) -> torch.Tensor: # pragma:no cover
assert self.module_name is not None
if self.verbose:
print("Infering ...")
# Inference using FHE server
inference_query = requests.post(
f"{self.server_remote_address}/compute",
files={
Expand All @@ -276,16 +301,30 @@ def remote_call(self, x: torch.Tensor) -> torch.Tensor: # pragma:no cover
end = time.time()
if self.verbose:
print(f"Inference done in {end - start} seconds")
# Unpack the results
# Deserialize and decrypt the result
assert inference_query.status_code == 200, inference_query.content.decode("utf-8")
encrypted_result = inference_query.content
decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_result)[0]
inferences.append(decrypted_prediction)
# Concatenate results and move them back to proper device
return torch.Tensor(numpy.array(inferences)).to(device=base_device)


# Add support for QAT models
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3992
class HybridFHEModel:
"""Convert a model to a hybrid model."""
"""Convert a model to a hybrid model.
This is done by converting targeted modules by RemoteModules.
This will modify the model in place.
Args:
model (nn.Module): The model to modify (in-place modification)
module_names (Union[str, List[str]]): The module name(s) to replace with FHE server.
server_remote_address): The remote address of the FHE server
model_name (str): Model name identifier
verbose (int): If logs should be printed when interacting with FHE server
"""

def __init__(
self,
Expand Down Expand Up @@ -348,7 +387,7 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
"""
# Set the fhe mode in each remote module
for module in self.remote_modules.values():
module.fhe_local_mode = FHEMode(fhe)
module.fhe_local_mode = HybridFHEMode(fhe)
x = self.model(x)
return x

Expand Down Expand Up @@ -392,9 +431,9 @@ def init_client(
def compile_model(
self,
x: torch.Tensor,
n_bits: int = 8,
rounding_threshold_bits: Optional[int] = 8,
p_error: float = 0.01,
n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
configuration: Optional[Configuration] = None,
):
"""Compiles the specific layers to FHE.
Expand All @@ -413,7 +452,7 @@ def compile_model(
# We do a forward pass where we accumulate inputs to use for compilation
for name in self.module_names:
# default is "calibrate"
self.remote_modules[name].fhe_local_mode = FHEMode.CALIBRATE
self.remote_modules[name].fhe_local_mode = HybridFHEMode.CALIBRATE
self.model(x)

self.configuration = configuration
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,7 +1056,7 @@ def test_net_has_no_tlu(
if num_inputs > 1 and use_qat:
return

use_conv = isinstance(input_shape, tuple) and len(input_shape) == 3
use_conv = isinstance(input_shape, tuple) and len(input_shape) > 1

net = module(use_conv, use_qat, input_shape, n_bits)
net.eval()
Expand Down
6 changes: 3 additions & 3 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run_hybrid_model_test(
# Create a hybrid model
hybrid_model = HybridFHEModel(model, module_names)
hybrid_model.compile_model(
inputs, n_bits=8, rounding_threshold_bits=8, configuration=configuration
inputs, p_error=0.01, n_bits=8, rounding_threshold_bits=8, configuration=configuration
)

# Check we can run the simulate locally
Expand Down Expand Up @@ -92,7 +92,7 @@ def run_hybrid_model_test(
def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy):
"""Test GPT2 hybrid."""

# Get GPT2 from Huggingface
# Get GPT2 from Hugging Face
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
Expand All @@ -106,7 +106,7 @@ def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy):
def test_gpt2_hybrid_mlp_module_not_found():
"""Test GPT2 hybrid."""

# Get GPT2 from Huggingface
# Get GPT2 from Hugging Face
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)

Expand Down
Loading

0 comments on commit c964cf2

Please sign in to comment.