Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New Feature]: Add support for transposing a tensor along a specified axes #443

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,12 @@ void bind_plain_tensor(py::module &m, const std::string &name) {
.def("replicate", &type::replicate)
.def("broadcast", &type::broadcast)
.def("broadcast_", &type::broadcast_inplace)
.def("transpose", &type::transpose)
.def("transpose_", &type::transpose_inplace)
.def("transpose", py::overload_cast<>(&type::transpose, py::const_))
.def("transpose_", py::overload_cast<>(&type::transpose_inplace))
.def("transpose", py::overload_cast<const std::vector<size_t> &>(
&type::transpose, py::const_))
.def("transpose_", py::overload_cast<const std::vector<size_t> &>(
&type::transpose_inplace))
.def("serialize", [](type &obj) { return py::bytes(obj.save()); });
}

Expand Down Expand Up @@ -709,8 +713,12 @@ void bind_ckks_tensor(py::module &m) {
.def("reshape_", &CKKSTensor::reshape_inplace)
.def("broadcast", &CKKSTensor::broadcast)
.def("broadcast_", &CKKSTensor::broadcast_inplace)
.def("transpose", &CKKSTensor::transpose)
.def("transpose_", &CKKSTensor::transpose_inplace)
.def("transpose", py::overload_cast<>(&CKKSTensor::transpose, py::const_))
.def("transpose_", py::overload_cast<>(&CKKSTensor::transpose_inplace))
.def("transpose", py::overload_cast<const std::vector<size_t> &>(
&CKKSTensor::transpose, py::const_))
.def("transpose_", py::overload_cast<const std::vector<size_t> &>(
&CKKSTensor::transpose_inplace))
.def("scale", &CKKSTensor::scale);
}

Expand Down
8 changes: 8 additions & 0 deletions tenseal/cpp/tensors/ckkstensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,14 @@ shared_ptr<CKKSTensor> CKKSTensor::transpose_inplace() {

return shared_from_this();
}
shared_ptr<CKKSTensor> CKKSTensor::transpose(const vector<size_t>& permutation) const {
return this->copy()->transpose_inplace(permutation);
}
shared_ptr<CKKSTensor> CKKSTensor::transpose_inplace(const vector<size_t>& permutation) {
this->_data.transpose_inplace(permutation);

return shared_from_this();
}

double CKKSTensor::scale() const { return _init_scale; }
} // namespace tenseal
2 changes: 2 additions & 0 deletions tenseal/cpp/tensors/ckkstensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,

shared_ptr<CKKSTensor> transpose() const;
shared_ptr<CKKSTensor> transpose_inplace();
shared_ptr<CKKSTensor> transpose(const vector<size_t>& permutation) const;
shared_ptr<CKKSTensor> transpose_inplace(const vector<size_t>& permutation);

vector<size_t> shape_with_batch() const;
double scale() const override;
Expand Down
7 changes: 7 additions & 0 deletions tenseal/cpp/tensors/plain_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ class PlainTensor {
this->_data.transpose_inplace();
return *this;
}
PlainTensor<plain_t> transpose(const vector<size_t>& permutation) const {
return this->copy().transpose_inplace(permutation);
}
PlainTensor<plain_t>& transpose_inplace(const vector<size_t>& permutation) {
this->_data.transpose_inplace(permutation);
return *this;
}
/**
* Returns the element at position {idx1, idx2, ..., idxn} in the current
* shape
Expand Down
10 changes: 10 additions & 0 deletions tenseal/cpp/tensors/tensor_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ class TensorStorage {
this->_data = xt::transpose(this->_data);
return *this;
}

TensorStorage<dtype_t> transpose(const vector<size_t>& axes) const {
return this->copy().transpose_inplace(axes);
}

TensorStorage<dtype_t> transpose_inplace(const vector<size_t>& axes) {
this->_data = xt::transpose(this->_data, axes);
return *this;
}

/**
* Returns the element at position {idx1, idx2, ..., idxn} in the current
* shape
Expand Down
19 changes: 15 additions & 4 deletions tenseal/tensors/ckkstensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,23 @@ def broadcast_(self, shape: List[int]):
self.data.broadcast_(shape)
return self

def transpose(self):
def transpose(self, axes: List[int] = None) -> "CKKSTensor":
"Copies the transpose to a new tensor"
result = self.data.transpose()
result = None
if axes is None:
result = self.data.transpose()
elif isinstance(axes, list) and all(isinstance(x, int) for x in axes):
result = self.data.transpose(axes)
else:
raise TypeError("axes must be a list of integers")
return self._wrap(result)

def transpose_(self):
def transpose_(self, axes: List[int] = None) -> "CKKSTensor":
"Tries to transpose the tensor"
self.data.transpose_()
if axes is None:
self.data.transpose_()
elif isinstance(axes, list) and all(isinstance(x, int) for x in axes):
self.data.transpose_(axes)
else:
raise TypeError("axes must be a list of integers")
return self
18 changes: 14 additions & 4 deletions tenseal/tensors/plaintensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,24 @@ def broadcast_(self, shape: List[int]):
self.data.broadcast_(shape)
return self

def transpose(self):
def transpose(self, axes: List[int] = None):
"Copies the transpose to a new tensor"
new_tensor = PlainTensor(tensor=self.data.data(), shape=self.shape, dtype=self._dtype)
return new_tensor.transpose_()
if axes is None:
return new_tensor.transpose_()
elif isinstance(axes, list) and all(isinstance(x, int) for x in axes):
return new_tensor.transpose_(axes)
else:
raise TypeError("transpose axes must be a list of integers")

def transpose_(self):
def transpose_(self, axes: List[int] = None):
"Tries to transpose the tensor"
self.data.transpose_()
if axes is None:
self.data.transpose_()
elif isinstance(axes, list) and all(isinstance(x, int) for x in axes):
self.data.transpose_(axes)
else:
raise TypeError("transpose axes must be a list of integers")
return self

@classmethod
Expand Down
28 changes: 28 additions & 0 deletions tests/cpp/tensors/ckkstensor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,34 @@ TEST_P(CKKSTensorTest, TestTranspose) {
ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6}));
}

TEST_P(CKKSTensorTest, TestTransposeWithAxes) {
auto enc_type = get<1>(GetParam());

auto ctx = TenSEALContext::Create(scheme_type::ckks, 8192, -1,
{60, 40, 40, 60}, enc_type);
ASSERT_TRUE(ctx != nullptr);
ctx->generate_galois_keys();

auto ldata =
PlainTensor(vector<double>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
vector<size_t>({2, 3, 2}));

auto l = CKKSTensor::Create(ctx, ldata, std::pow(2, 40));

// Transpose with specified axes
auto res = l->transpose({0, 2, 1});
ASSERT_THAT(res->shape(), ElementsAreArray({2, 2, 3}));
ASSERT_THAT(l->shape(), ElementsAreArray({2, 3, 2}));
auto decr = res->decrypt();
ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}));

// Transpose inplace with specified axes
l->transpose_inplace({0, 2, 1});
ASSERT_THAT(l->shape(), ElementsAreArray({2, 2, 3}));
decr = l->decrypt();
ASSERT_TRUE(are_close(decr.data(), {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}));
}

TEST_P(CKKSTensorTest, TestSubscript) {
auto enc_type = get<1>(GetParam());

Expand Down
24 changes: 24 additions & 0 deletions tests/python/tenseal/tensors/test_ckks_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,27 @@ def test_transpose(context, data, shape):
assert tensor.shape == list(expected.shape)
result = np.array(tensor.decrypt().tolist())
assert np.allclose(result, expected, rtol=0, atol=0.01)

@pytest.mark.parametrize(
"data, shape, axes",
[
([i for i in range(6)], [1, 2, 3], [0, 2, 1]),
([i for i in range(12)], [2, 2, 3], [0, 2, 1]),
([i for i in range(2 * 3 * 4 * 5)], [2, 3, 4, 5], [0, 3, 2, 1]),
],
)
def test_transpose_with_axes(context, data, shape, axes):
tensor = ts.ckks_tensor(context, ts.plain_tensor(data, shape))

expected = np.transpose(np.array(data).reshape(shape), axes)

newt = tensor.transpose(axes)
assert tensor.shape == shape
assert newt.shape == list(expected.shape)
result = np.array(newt.decrypt().tolist())
assert np.allclose(result, expected, rtol=0, atol=0.01)

tensor.transpose_(axes)
assert tensor.shape == list(expected.shape)
result = np.array(tensor.decrypt().tolist())
assert np.allclose(result, expected, rtol=0, atol=0.01)
22 changes: 22 additions & 0 deletions tests/python/tenseal/tensors/test_plain_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,25 @@ def test_transpose(data, shape):
tensor.transpose_()
assert tensor.shape == list(expected.shape)
assert np.array(tensor.tolist()).any() == expected.any()

@pytest.mark.parametrize(
"data, shape, axes",
[
([i for i in range(6)], [1, 2, 3], [0, 2, 1]),
([i for i in range(12)], [2, 2, 3], [0, 2, 1]),
([i for i in range(2 * 3 * 4 * 5)], [2, 3, 4, 5], [0, 3, 2, 1]),
],
)
def test_transpose(data, shape, axes):
tensor = ts.plain_tensor(data, shape)

expected = np.transpose(np.array(data).reshape(shape), axes)

newt = tensor.transpose(axes)
assert tensor.shape == shape
assert newt.shape == list(expected.shape)
assert np.array(newt.tolist()).any() == expected.any()

tensor.transpose_(axes)
assert tensor.shape == list(expected.shape)
assert np.array(tensor.tolist()).any() == expected.any()