Skip to content

Commit

Permalink
chore: change save_and_clear_private_info method to consider brevitas…
Browse files Browse the repository at this point in the history
… models
  • Loading branch information
kcelia authored Aug 14, 2024
1 parent 1adacac commit 845d9cf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,10 @@ def save_and_clear_private_info(self, path: Path, via_mlir=True):

# Save the model with a specific filename
model_path = path / "model.pth"
torch.save(self.model, model_path.resolve())

# Save the model state dict due to a Brevitas issue
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4571
torch.save(self.model.state_dict(), model_path.resolve())

# Save the FHE circuit in the same directory
self._save_fhe_circuit(path, via_mlir=via_mlir)
Expand Down
9 changes: 9 additions & 0 deletions tests/torch/test_hybrid_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,15 @@ def test_hybrid_brevitas_qat_model():
hybrid_model = HybridFHEModel(model, module_names="sub_module")
hybrid_model.compile_model(x=inputs)

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)

# Get the temp directory path
hybrid_model.save_and_clear_private_info(temp_dir_path)

# Check that files are there
assert (temp_dir_path / "model.pth").exists()


# Dependency 'huggingface-hub' raises a 'FutureWarning' from version 0.23.0 when calling the
# 'from_pretrained' method
Expand Down

0 comments on commit 845d9cf

Please sign in to comment.