From 4b8cc4cbd6bebc6f4f7d23a4c15102d7587ee339 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 22 May 2024 21:10:25 +0700 Subject: [PATCH] [syft/serde] - move all code related to torch tensor serde into a try-exception block in `third_party.py` - move torch back to under data_science marker in setup.cfg --- packages/syft/setup.cfg | 4 +- packages/syft/src/syft/serde/array.py | 18 ------- packages/syft/src/syft/serde/third_party.py | 58 ++++++++++++++------- 3 files changed, 40 insertions(+), 40 deletions(-) diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index dbb7b19c613..0418fb6b0af 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -66,8 +66,6 @@ syft = rich==13.7.1 jinja2==3.1.4 tenacity==8.3.0 - # backend.dockerfile installs torch separately, so update the version over there as well! - torch==2.3.0 install_requires = %(syft)s @@ -88,6 +86,8 @@ data_science = opendp==0.9.2 evaluate==0.4.1 recordlinkage==0.16 + # backend.dockerfile installs torch separately, so update the version over there as well! + torch==2.3.0 dev = %(test_plugins)s diff --git a/packages/syft/src/syft/serde/array.py b/packages/syft/src/syft/serde/array.py index fa8db0d9932..56158225857 100644 --- a/packages/syft/src/syft/serde/array.py +++ b/packages/syft/src/syft/serde/array.py @@ -1,7 +1,6 @@ # third party import numpy as np from numpy import frombuffer -import torch # relative from .arrow import numpy_deserialize @@ -151,22 +150,5 @@ # deserialize=lambda buffer: frombuffer(buffer, dtype=numpy_scalar_type), # ) - -# Add support for torch tensors -def torch_serialize(tensor: torch.Tensor) -> bytes: - return numpy_serialize(tensor.numpy()) - - -def torch_deserialize(buffer: bytes) -> torch.tensor: - np_array = numpy_deserialize(buffer) - return torch.from_numpy(np_array) - - -recursive_serde_register( - torch.Tensor, - serialize=torch_serialize, - deserialize=lambda data: torch_deserialize(data), -) - # how else do you import a relative file to execute it? NOTHING = None diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index be4750f5794..cec967b3526 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -24,14 +24,14 @@ from result import Err from result import Ok from result import Result -import torch -from torch._C import _TensorMeta # relative from ..types.dicttuple import DictTuple from ..types.dicttuple import _Meta as _DictTupleMetaClass from ..types.syft_metaclass import EmptyType from ..types.syft_metaclass import PartialModelMetaclass +from .array import numpy_deserialize +from .array import numpy_serialize from .deserialize import _deserialize as deserialize from .recursive_primitives import _serialize_kv_pairs from .recursive_primitives import deserialize_kv @@ -107,24 +107,6 @@ def deserialize_series(blob: bytes) -> Series: deserialize=deserialize_series, ) - -def serialize_torch_tensor_meta(t: _TensorMeta) -> bytes: - buffer = BytesIO() - torch.save(t, buffer) - return buffer.getvalue() - - -def deserialize_torch_tensor_meta(buf: bytes) -> _TensorMeta: - buffer = BytesIO(buf) - return torch.load(buffer) - - -recursive_serde_register( - _TensorMeta, - serialize=serialize_torch_tensor_meta, - deserialize=deserialize_torch_tensor_meta, -) - recursive_serde_register( datetime, serialize=lambda x: serialize(x.isoformat(), to_bytes=True), @@ -198,6 +180,42 @@ def serialize_bytes_io(io: BytesIO) -> bytes: pass +try: + # third party + import torch + from torch._C import _TensorMeta + + def serialize_torch_tensor_meta(t: _TensorMeta) -> bytes: + buffer = BytesIO() + torch.save(t, buffer) + return buffer.getvalue() + + def deserialize_torch_tensor_meta(buf: bytes) -> _TensorMeta: + buffer = BytesIO(buf) + return torch.load(buffer) + + recursive_serde_register( + _TensorMeta, + serialize=serialize_torch_tensor_meta, + deserialize=deserialize_torch_tensor_meta, + ) + + def torch_serialize(tensor: torch.Tensor) -> bytes: + return numpy_serialize(tensor.numpy()) + + def torch_deserialize(buffer: bytes) -> torch.tensor: + np_array = numpy_deserialize(buffer) + return torch.from_numpy(np_array) + + recursive_serde_register( + torch.Tensor, + serialize=torch_serialize, + deserialize=lambda data: torch_deserialize(data), + ) + +except Exception: # nosec + pass + # unsure why we have to register the object not the type but this works recursive_serde_register(np.core._ufunc_config._unspecified())