diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py
index 3b54bca42..0da5eadee 100644
--- a/src/concrete/ml/torch/hybrid_model.py
+++ b/src/concrete/ml/torch/hybrid_model.py
@@ -37,6 +37,7 @@ class HybridFHEMode(enum.Enum):
     SIMULATE = "simulate"  # Use FHE simulation
     CALIBRATE = "calibrate"  # Use calibration (to run before FHE compilation)
     EXECUTE = "execute"  # Use FHE execution
+    TORCH = "torch"  # Use torch layers
 
 
 def tuple_to_underscore_str(tup: Tuple) -> str:
@@ -236,15 +237,15 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
         Raises:
             ValueError: if local_fhe_mode is not supported
         """
-        # - disable: torch module
+        # - disable: quantized module
         # - remote: client-server
         # - simulate: compiled simulation
         # - calibrate: calibration
         if self.fhe_local_mode not in {
-            HybridFHEMode.DISABLE,
             HybridFHEMode.CALIBRATE,
             HybridFHEMode.REMOTE,
+            HybridFHEMode.TORCH,
             None,
         }:
             # Using quantized module
@@ -253,14 +254,6 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
                 self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
             )
 
-        elif self.fhe_local_mode == HybridFHEMode.DISABLE:
-            # Calling torch
-            assert self.private_module is not None
-            y = self.private_module.forward(
-                x.detach(),
-            )
-            assert isinstance(y, (QuantTensor, torch.Tensor))
-
         elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
             # Calling torch + gathering calibration data
             assert self.private_module is not None
@@ -271,7 +264,10 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
         elif self.fhe_local_mode == HybridFHEMode.REMOTE:  # pragma:no cover
             # Remote call
             y = self.remote_call(x)
-
+        elif self.fhe_local_mode == HybridFHEMode.TORCH:
+            # Using torch layers
+            assert self.private_module is not None
+            y = self.private_module(x)
         else:  # pragma:no cover
             # Shouldn't happen
             raise ValueError(f"{self.fhe_local_mode} is not recognized")
@@ -369,7 +365,7 @@ def __init__(
             name: self._get_module_by_name(self.model, name) for name in self.module_names
         }
         self.remote_modules: Dict[str, RemoteModule] = {}
-        self.private_q_modules: dict = {}
+        self.private_q_modules: Dict[str, QuantizedModule] = {}
         self.configuration: Optional[Configuration] = None
         self.model_name = model_name
         self.verbose = verbose
@@ -413,9 +409,7 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
         Returns:
             (torch.Tensor): The output tensor.
         """
-        # Set the fhe mode in each remote module
-        for module in self.remote_modules.values():
-            module.fhe_local_mode = HybridFHEMode(fhe)
+        self.set_fhe_mode(fhe)
         x = self.model(x)
         return x
@@ -478,9 +472,7 @@ def compile_model(
                 encryption parameters. If not specified, a default configuration is used.
         """
         # We do a forward pass where we accumulate inputs to use for compilation
-        for name in self.module_names:
-            # default is "calibrate"
-            self.remote_modules[name].fhe_local_mode = HybridFHEMode.CALIBRATE
+        self.set_fhe_mode(HybridFHEMode.CALIBRATE)
         self.model(x)
 
         self.configuration = configuration
@@ -523,12 +515,15 @@ def _save_fhe_circuit(self, path: Path, via_mlir=False):
 
         model_path = Path(path)
         for module_name in self.module_names:
+            onnx_model = self.private_q_modules[module_name].onnx_model
+
+            # mypy
+            assert onnx_model is not None
             input_shapes = [
                 tuple(elt.dim_value for elt in onnx_input.type.tensor_type.shape.dim)
-                for onnx_input in self.private_q_modules[  # pylint: disable=protected-access
-                    self.module_names[0]
-                ]._onnx_model.graph.input
+                for onnx_input in onnx_model.graph.input
             ]
+            assert len(input_shapes) == 1, "Multi-input circuits not supported yet"
             model_module_path = model_path.resolve() / module_name
             model_module_path.mkdir(exist_ok=True)
diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py
index 40611d5bc..c4d655d42 100644
--- a/tests/torch/test_hybrid_converter.py
+++ b/tests/torch/test_hybrid_converter.py
@@ -50,7 +50,7 @@ def run_hybrid_llm_test(
     # Create a hybrid model
     hybrid_model = HybridFHEModel(model, module_names)
     hybrid_model.compile_model(
-        inputs, p_error=0.01, n_bits=8, rounding_threshold_bits=8, configuration=configuration
+        inputs, p_error=0.1, n_bits=9, rounding_threshold_bits=8, configuration=configuration
     )
 
     if has_pbs:
@@ -71,35 +71,34 @@ def run_hybrid_llm_test(
     # Check we can run the simulate locally
     logits_simulate = hybrid_model(inputs, fhe="simulate").logits
     logits_disable = hybrid_model(inputs, fhe="disable").logits
-    logits_original = model(inputs).logits
+    logits_original = hybrid_model(inputs, fhe="torch").logits
 
     # Ensure logits_disable and logits_original return the same output for the logits
-    assert torch.allclose(logits_disable, logits_original, atol=1e-7), "Outputs do not match!"
+    assert torch.allclose(logits_disable, logits_simulate, atol=1e-7), "Outputs do not match!"
 
     # Compare the topk accuracy of the FHE simulate circuit vs. the original.
-    k = 100
+    k = 5
+
+    # Check that the topk next tokens are similar for the different FHE modes
+    # and the original model.
 
     # Get the topk indices for logits_disable and logits_simulate
     topk_disable = logits_disable.topk(k, dim=-1).indices
     topk_simulate = logits_simulate.topk(k, dim=-1).indices
+    topk_original = logits_original.topk(k, dim=-1).indices
 
-    # Prepare tensors for broadcasting
-    expanded_simulate = topk_simulate.unsqueeze(-1)
-    expanded_disable = topk_disable.unsqueeze(-2)
-
-    # Compute if elements of topk_simulate are in topk_disable for each token
-    (expanded_simulate == expanded_disable).any(-1)
-
-    # Make sure accuracy is above a certain threshold
-    # Even with a small tolerance the test is flaky
-    # Commenting the assertion for now until issue is resolved
-    # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3905
+    # Compute accuracy of disable and simulate by checking
+    # how many labels correspond with the topk_original
+    accuracy_disable = (topk_disable == topk_original).float().mean().item()
+    accuracy_simulate = (topk_simulate == topk_original).float().mean().item()
 
-    # Compute average of these counts (the accuracy)
-    # accuracy = is_in.float().mean()
-    # To use expected accuracy until the check is done
-    assert expected_accuracy > -1
-    # assert accuracy >= expected_accuracy, "Expected accuracy GPT2 hybrid not matched."
+    # Assert that both accuracy values are above the expected threshold
+    assert (
+        accuracy_disable >= expected_accuracy
+    ), f"Disable accuracy {accuracy_disable:.4f} is below the expected {expected_accuracy:.4f}"
+    assert (
+        accuracy_simulate >= expected_accuracy
+    ), f"Simulate accuracy {accuracy_simulate:.4f} is below the expected {expected_accuracy:.4f}"
 
     with tempfile.TemporaryDirectory() as temp_dir:
         temp_dir_path = Path(temp_dir)
@@ -127,9 +126,9 @@ def run_hybrid_llm_test(
 @pytest.mark.parametrize(
     "list_or_str_private_modules_names, expected_accuracy, has_pbs",
     [
-        ("transformer.h.0.mlp", 0.934, True),
-        (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.42, True),
-        ("transformer.h.0.mlp.c_fc", 0.986, False),
+        ("transformer.h.0.mlp", 0.95, True),
+        (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True),
+        ("transformer.h.0.mlp.c_fc", 1.0, False),
     ],
 )
 def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy, has_pbs):
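For reference, a minimal sketch of how the new "torch" mode is meant to be used, mirroring run_hybrid_llm_test above. The GPT-2 model, module name, prompt, and compilation parameters are illustrative assumptions and not part of this patch:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from concrete.ml.torch.hybrid_model import HybridFHEModel

    # Assumed setup: a GPT-2 model with one MLP block handled as a private module
    model = GPT2LMHeadModel.from_pretrained("gpt2", use_cache=False)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    inputs = tokenizer.encode("Hello, world", return_tensors="pt")

    hybrid_model = HybridFHEModel(model, ["transformer.h.0.mlp"])
    hybrid_model.compile_model(inputs, p_error=0.1, n_bits=9, rounding_threshold_bits=8)

    # fhe="torch" (added in this diff) runs the private module as plain torch layers,
    # fhe="disable" runs the quantized module in the clear,
    # fhe="simulate" uses FHE simulation of the compiled circuit.
    logits_torch = hybrid_model(inputs, fhe="torch").logits
    logits_disable = hybrid_model(inputs, fhe="disable").logits
    logits_simulate = hybrid_model(inputs, fhe="simulate").logits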