From 331d60b91e23602c81c6a95183e23ec4610b3a04 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 28 Nov 2024 11:43:17 +0100 Subject: [PATCH] perf(frontend/compiler): deser keyset using path instead of buffer reduce memory usage by avoiding unecessary copy --- .../lib/Bindings/Python/CompilerAPIModule.cpp | 20 +++++++++++++++++++ .../concrete/fhe/compilation/keys.py | 12 +++++------ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index c1183a8128..0d92749237 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1433,6 +1433,26 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return std::make_unique(std::move(keyset)); }, "Deserialize a Keyset from bytes.", arg("bytes")) + .def_static( + "deserialize_from_file", + [](const std::string path) { + std::ifstream ifs; + ifs.open(path); + if (!ifs.good()) { + throw std::runtime_error("Failed to open keyset file " + path); + } + + auto keysetProto = Message(); + auto maybeError = keysetProto.readBinaryFromIstream( + ifs, mlir::concretelang::python::DESER_OPTIONS); + if (maybeError.has_failure()) { + throw std::runtime_error("Failed to deserialize keyset." + + maybeError.as_failure().error().mesg); + } + auto keyset = Keyset::fromProto(keysetProto); + return std::make_unique(std::move(keyset)); + }, + "Deserialize a Keyset from a file.", arg("path")) .def( "serialize", [](Keyset &keySet) { diff --git a/frontends/concrete-python/concrete/fhe/compilation/keys.py b/frontends/concrete-python/concrete/fhe/compilation/keys.py index 72942cb4c1..9ab347330f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/keys.py +++ b/frontends/concrete-python/concrete/fhe/compilation/keys.py @@ -133,7 +133,7 @@ def load(self, location: Union[str, Path]): message = f"Unable to load keys from {location} because it doesn't exist" raise ValueError(message) - keys = Keys.deserialize(bytes(location.read_bytes())) + keys = Keys.deserialize(location) # pylint: disable=protected-access self._specs = None @@ -185,20 +185,20 @@ def serialize(self) -> bytes: return serialized_keyset @staticmethod - def deserialize(serialized_keys: bytes) -> "Keys": + def deserialize(path: Path) -> "Keys": """ - Deserialize keys from bytes. + Deserialize keys from file. Args: - serialized_keys (bytes): - previously serialized keys + path (Path): + previously serialized keys path Returns: Keys: deserialized keys """ - keyset = Keyset.deserialize(serialized_keys) + keyset = Keyset.deserialize_from_file(str(path)) # pylint: disable=protected-access result = Keys(None)