Skip to content

Commit

Permalink
chore: add forward method to hybrid model (#846)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored Aug 29, 2024
1 parent fd14a5e commit ee58a68
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,19 @@ def _replace_modules(self):
)
setattr(parent_module, last, remote_module)

def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
"""Forward pass of the hybrid model.
Args:
x (torch.Tensor): The input tensor.
fhe (str): The Fully Homomorphic Encryption (FHE) mode (default is "disable").
Returns:
torch.Tensor: The output tensor.
"""
self.set_fhe_mode(fhe)
return self.model(x)

def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
"""Call method to run the model locally with a fhe mode.
Expand All @@ -409,9 +422,7 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
Returns:
(torch.Tensor): The output tensor.
"""
self.set_fhe_mode(fhe)
x = self.model(x)
return x
return self.forward(x, fhe)

@staticmethod
def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.Module]:
Expand Down

0 comments on commit ee58a68

Please sign in to comment.