From ee58a68f5490e1c5ac7084e321ee12fdad941ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jordan=20Fr=C3=A9ry?= Date: Thu, 29 Aug 2024 17:26:56 +0200 Subject: [PATCH] chore: add forward method to hybrid model (#846) --- src/concrete/ml/torch/hybrid_model.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py index 5fee3fee6..07797dea0 100644 --- a/src/concrete/ml/torch/hybrid_model.py +++ b/src/concrete/ml/torch/hybrid_model.py @@ -399,6 +399,19 @@ def _replace_modules(self): ) setattr(parent_module, last, remote_module) + def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: + """Forward pass of the hybrid model. + + Args: + x (torch.Tensor): The input tensor. + fhe (str): The Fully Homomorphic Encryption (FHE) mode (default is "disable"). + + Returns: + torch.Tensor: The output tensor. + """ + self.set_fhe_mode(fhe) + return self.model(x) + def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: """Call method to run the model locally with a fhe mode. @@ -409,9 +422,7 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: Returns: (torch.Tensor): The output tensor. """ - self.set_fhe_mode(fhe) - x = self.model(x) - return x + return self.forward(x, fhe) @staticmethod def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.Module]: