From 6b5a4d0a02d0fcbf877744713cffc6a9c0a363e9 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 18 Mar 2024 18:18:47 +0100 Subject: [PATCH 1/4] deserialize MutMatrix --- osiris/cairo/serde/deserialize.py | 28 ++++++++++++++++++++++++++++ tests/test_deserialize.py | 5 +++++ 2 files changed, 33 insertions(+) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index f7158e7..2c7a164 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -1,4 +1,5 @@ import numpy as np +import re from osiris.cairo.serde.utils import felt_to_int, from_fp @@ -17,6 +18,9 @@ def deserializer(serialized, dtype): elif dtype.startswith('Tensor<'): return deserialize_tensor(serialized, dtype) + elif dtype.startswith('MutMatrix<'): + return deserialize_matrix(serialized, dtype) + elif dtype.startswith('('): # Tuple return deserialize_tuple(serialized, dtype) @@ -80,6 +84,30 @@ def deserialize_tuple(serialized, dtype): return part1, part2 +def deserialize_matrix(serialized, dtype): + + # Extract inner dtype + pattern = r"<(.*)>" + inner_dtype = re.search(pattern, dtype).group(1) + + # Extract the matrix content and shape from the serialized string + content, shape_str = serialized.split("} ") + # Last two numbers are the shape + shape = tuple(map(int, shape_str.split()[-2:])) + + # Use regex to find all occurrences of ': ' followed by any characters until the next ' :' or end of string + pattern = r': (.*?)(?=\s\d+: |$)' + elements = re.findall(pattern, content) + + # Deserialize each element using the appropriate deserializer based on dtype + deserialized_elements = [deserializer( + element, inner_dtype) for element in elements] + + # Reshape the deserialized elements into a numpy array of the specified shape + matrix = np.array(deserialized_elements).reshape(shape) + return matrix + + def find_nth_occurrence(string, sub_string, n): start_index = string.find(sub_string) while start_index >= 0 and n > 1: diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index e78502e..0c4eaff 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -59,6 +59,11 @@ def test_deserialize_tensor_fixed_point(): deserialized = deserializer(serialized, 'Tensor') assert np.allclose(deserialized, expected_array, atol=1e-7) +def test_deserialize_matrix_fixed_point(): + serialized = "{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2" + expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]]) + deserialized = deserializer(serialized, 'MutMatrix') + assert np.allclose(deserialized, expected_array, atol=1e-7) def test_deserialize_tuple_int(): serialized = '1 3' From 74cbc027185ebb0e4200353ac2635705f96749fe Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 18 Mar 2024 18:19:35 +0100 Subject: [PATCH 2/4] update fp deserializer --- osiris/cairo/serde/deserialize.py | 2 +- tests/test_deserialize.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index 2c7a164..0b85423 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -31,7 +31,7 @@ def deserializer(serialized, dtype): def deserialize_fp(serialized): parts = serialized.split() value = from_fp(int(parts[0])) - if len(parts) > 1 and parts[1] == '1': # Check for negative sign + if len(parts) > 1 and parts[1] == 'true': # Check for negative sign value = -value return value diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index 0c4eaff..94c07b0 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -16,11 +16,11 @@ def test_deserialize_int(): def test_deserialize_fp(): - serialized = '2780037 0' + serialized = '2780037 false' deserialized = deserializer(serialized, 'FP16x16') assert isclose(deserialized, 42.42, rel_tol=1e-7) - serialized = '2780037 1' + serialized = '2780037 true' deserialized = deserializer(serialized, 'FP16x16') assert isclose(deserialized, -42.42, rel_tol=1e-7) @@ -36,7 +36,7 @@ def test_deserialize_array_int(): def test_deserialize_arr_fixed_point(): - serialized = '[2780037 0 2780037 1]' + serialized = '[2780037 false 2780037 true]' deserialized = deserializer(serialized, 'Span') expected = np.array([42.42, -42.42], dtype=np.float64) assert np.all(np.isclose(deserialized, expected, atol=1e-7)) @@ -54,7 +54,7 @@ def test_deserialize_tensor_int(): def test_deserialize_tensor_fixed_point(): - serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1]' + serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true]' expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]]) deserialized = deserializer(serialized, 'Tensor') assert np.allclose(deserialized, expected_array, atol=1e-7) @@ -80,13 +80,13 @@ def test_deserialize_tuple_span(): def test_deserialize_tuple_span_tensor_fp(): - serialized = '[1 2] [2 2] [2780037 0 2780037 0 2780037 1 2780037 1]' + serialized = '[1 2] [2 2] [2780037 false 2780037 false 2780037 true 2780037 true]' deserialized = deserializer(serialized, '(Span, Tensor)') expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]])) npt.assert_array_equal(deserialized[0], expected[0]) assert np.allclose(deserialized[1], expected[1], atol=1e-7) - serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]' + serialized = '[2 2] [2780037 false 2780037 false 2780037 true 2780037 true] [1 2]' deserialized = deserializer(serialized, '(Tensor, Span)') expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2])) assert np.allclose(deserialized[0], expected[0], atol=1e-7) From d2c76522bdd4a6c878039b646790c2e6e25bd6cc Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 18 Mar 2024 18:30:40 +0100 Subject: [PATCH 3/4] deserialize tuple of span and matrix --- osiris/cairo/serde/deserialize.py | 40 ++++++++++++++++++------------- tests/test_deserialize.py | 7 ++++++ 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index 0b85423..a3335b7 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -64,24 +64,32 @@ def deserialize_tensor(serialized, dtype): def deserialize_tuple(serialized, dtype): types = dtype[1:-1].split(', ') - if 'Tensor' in types[0]: - tensor_end = find_nth_occurrence(serialized, ']', 2) - depth = 1 - for i in range(tensor_end, len(serialized)): - if serialized[i] == '[': - depth += 1 - elif serialized[i] == ']': - depth -= 1 - if depth == 0: - tensor_end = i + 1 - break - part1 = deserializer(serialized[:tensor_end].strip(), types[0]) - part2 = deserializer(serialized[tensor_end:].strip(), types[1]) - else: - split_index = serialized.find(']') + 2 + # Check if there is no space between span and matrix. + is_no_space = re.search(r']\{', serialized) + if is_no_space: + split_index = is_no_space.start() + 1 part1 = deserializer(serialized[:split_index].strip(), types[0]) part2 = deserializer(serialized[split_index:].strip(), types[1]) - return part1, part2 + return part1, part2 + else: + if 'Tensor' in types[0]: + tensor_end = find_nth_occurrence(serialized, ']', 2) + depth = 1 + for i in range(tensor_end, len(serialized)): + if serialized[i] == '[': + depth += 1 + elif serialized[i] == ']': + depth -= 1 + if depth == 0: + tensor_end = i + 1 + break + part1 = deserializer(serialized[:tensor_end].strip(), types[0]) + part2 = deserializer(serialized[tensor_end:].strip(), types[1]) + else: + split_index = serialized.find(']') + 2 + part1 = deserializer(serialized[:split_index].strip(), types[0]) + part2 = deserializer(serialized[split_index:].strip(), types[1]) + return part1, part2 def deserialize_matrix(serialized, dtype): diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index 94c07b0..ae53c45 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -91,3 +91,10 @@ def test_deserialize_tuple_span_tensor_fp(): expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2])) assert np.allclose(deserialized[0], expected[0], atol=1e-7) npt.assert_array_equal(deserialized[1], expected[1]) + +def test_deserialize_tuple_matrix_fp(): + serialized = '[1 2]{0: 2780037 false 2: 2780037 false 1: 2780037 true 3: 2780037 true} 4 2 2' + deserialized = deserializer(serialized, '(Span, MutMatrix)') + expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]])) + npt.assert_array_equal(deserialized[0], expected[0]) + assert np.allclose(deserialized[1], expected[1], atol=1e-7) \ No newline at end of file From 2664a69287849d2edd01ad1da26b31e02ad1b843 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 18 Mar 2024 18:31:48 +0100 Subject: [PATCH 4/4] fix lint --- osiris/cairo/serde/deserialize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index a3335b7..ec7cd41 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -1,6 +1,7 @@ -import numpy as np import re +import numpy as np + from osiris.cairo.serde.utils import felt_to_int, from_fp