diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py index 5fee3fee6..07797dea0 100644 --- a/src/concrete/ml/torch/hybrid_model.py +++ b/src/concrete/ml/torch/hybrid_model.py @@ -399,6 +399,19 @@ def _replace_modules(self): ) setattr(parent_module, last, remote_module) + def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: + """Forward pass of the hybrid model. + + Args: + x (torch.Tensor): The input tensor. + fhe (str): The Fully Homomorphic Encryption (FHE) mode (default is "disable"). + + Returns: + torch.Tensor: The output tensor. + """ + self.set_fhe_mode(fhe) + return self.model(x) + def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: """Call method to run the model locally with a fhe mode. @@ -409,9 +422,7 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: Returns: (torch.Tensor): The output tensor. """ - self.set_fhe_mode(fhe) - x = self.model(x) - return x + return self.forward(x, fhe) @staticmethod def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.Module]: