Skip to content

Commit

Permalink
fix serialize unsigned
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Dec 23, 2023
1 parent fac6987 commit 31b4b42
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
5 changes: 4 additions & 1 deletion osiris/cairo/serde/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ def __init__(self, shape: tuple, data):
self.shape = shape
self.data = data

class UnsignedInt:
def __init__(self, mag):
self.mag = mag

class SignedInt:
def __init__(self, mag, sign):
Expand All @@ -30,7 +33,7 @@ def create_tensor_from_array(arr: np.ndarray, fp_impl='FP16x16'):

for value in flat_array:
if isinstance(value, (np.unsignedinteger)):
tensor_data.append(value)
tensor_data.append(UnsignedInt(value))
elif isinstance(value, (int, np.integer, np.signedinteger)):
sign = 0 if value >= 0 else 1
tensor_data.append(SignedInt(abs(value), sign))
Expand Down
5 changes: 4 additions & 1 deletion osiris/cairo/serde/serialize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from osiris.cairo.serde.data_structures import FixedPoint, SignedInt, Tensor
from osiris.cairo.serde.data_structures import FixedPoint, SignedInt, Tensor, UnsignedInt


def serializer(data) -> list[str]:
Expand All @@ -23,5 +23,8 @@ def serializer(data) -> list[str]:
return serialized_tensor
elif isinstance(data, (SignedInt, FixedPoint)):
return [str(data.mag), str(data.sign)]
elif isinstance(data, UnsignedInt):
return [str(data.mag)]

else:
raise ValueError("Unsupported data type for serialization")
9 changes: 8 additions & 1 deletion tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@ def test_create_tensor_from_array_with_invalid_input():
create_tensor_from_array("not a numpy array")


def test_serializer_for_tensor():
def test_serializer_for_tensor_signedint():
arr = np.array([[1, 2], [3, 4]], dtype=np.int64)
tensor = create_tensor_from_array(arr)
serialized_data = serializer(tensor)
assert isinstance(serialized_data, list)


def test_serializer_for_tensor_uint():
arr = np.array([[1, 2], [3, 4]], dtype=np.uint64)
tensor = create_tensor_from_array(arr)
serialized_data = serializer(tensor)
assert isinstance(serialized_data, list)

0 comments on commit 31b4b42

Please sign in to comment.