Skip to content

Commit 70d65d3

Browse files
Add to_row_matrix and to_column_matrix at array
1 parent 41999a3 commit 70d65d3

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

Compiler/types.py

+22
Original file line numberDiff line numberDiff line change
@@ -6322,6 +6322,28 @@ def sort(self, n_threads=None, batcher=False, n_bits=None):
63226322
from . import sorting
63236323
sorting.radix_sort(self, self, n_bits=n_bits)
63246324

6325+
def to_row_matrix(self):
6326+
"""
6327+
Returns the array as 1xN matrix.
6328+
6329+
Warning: This operation is in-place (without copying data), i.e., all changes to the values of the matrix will also affect the original array.
6330+
:return: Matrix
6331+
"""
6332+
assert self.value_type.n_elements() == 1 and \
6333+
self.value_type.mem_size() == 1
6334+
return Matrix(1, self.length, self.value_type, address=self.address)
6335+
6336+
def to_column_matrix(self):
6337+
"""
6338+
Returns the array as Nx1 matrix.
6339+
6340+
Warning: This operation is in-place (without copying data), i.e., all changes to the values of the matrix will also affect the original array.
6341+
:return: Matrix
6342+
"""
6343+
assert self.value_type.n_elements() == 1 and \
6344+
self.value_type.mem_size() == 1
6345+
return Matrix(self.length, 1, self.value_type, address=self.address)
6346+
63256347
def Array(self, size):
63266348
# compatibility with registers
63276349
return Array(size, self.value_type)

Programs/Source/test_dot.mpc

+8-11
Original file line numberDiff line numberDiff line change
@@ -49,35 +49,32 @@ def test_matrix(expected, actual):
4949

5050
crash()
5151

52-
break_point()
53-
def hacky_array_dot_matrix(arr, mat):
54-
# Arrays sadly do not have a dot function, therefore the array is converted into a 1 times n Matrix by copying memory addresses.
55-
tmp = sint.Matrix(rows=1, columns=len(arr), address=arr.address)
56-
result = tmp.dot(mat)
57-
return sint.Array(mat.shape[1], result.address)
58-
5952
start_timer(3)
6053

61-
e3 = hacky_array_dot_matrix(a, c)
54+
e3 = a.to_row_matrix().dot(c).to_array()
6255
# b[0] = e3[0]
63-
f3 = hacky_array_dot_matrix(b, d)
56+
f3 = b.to_row_matrix().dot(d).to_array()
57+
g3 = c.dot(b.to_column_matrix()).to_array()
6458

6559
stop_timer(3)
6660

6761
e3 = e3.reveal()
6862
f3 = f3.reveal()
63+
g3 = g3.reveal()
6964

7065
e3.print_reveal_nested()
7166
f3.print_reveal_nested()
67+
g3.print_reveal_nested()
7268

7369
test_array([70, 80, 90], e3)
7470
test_array([56, 50, 44, 38], f3)
71+
test_array([10, 28, 46, 64], g3)
7572

7673
start_timer(4)
7774

78-
e4 = hacky_array_dot_matrix(a, c)
75+
e4 = a.to_row_matrix().dot(c).to_array()
7976
b[-1] = e4[0]
80-
f4 = hacky_array_dot_matrix(b, d)
77+
f4 = b.to_row_matrix().dot(d).to_array()
8178

8279
stop_timer(4)
8380

0 commit comments

Comments
 (0)