From fd53d16f987b81db361bebb2124d5bb373414982 Mon Sep 17 00:00:00 2001 From: Hakeem Kazeem Date: Tue, 19 Sep 2023 10:31:43 +0100 Subject: [PATCH] fixed gather i8, i32 , u32 --- nodegen/node/gather.py | 195 ++++++++++++++++-- src/operators/tensor/core.cairo | 2 +- .../implementations/tensor_fp16x16.cairo | 2 +- .../implementations/tensor_fp32x32.cairo | 2 +- .../implementations/tensor_fp64x64.cairo | 2 +- .../implementations/tensor_fp8x23.cairo | 2 +- .../tensor/implementations/tensor_i32.cairo | 5 +- .../tensor/implementations/tensor_i8.cairo | 5 +- .../tensor/implementations/tensor_u32.cairo | 7 +- src/operators/tensor/math/gather.cairo | 26 +-- src/tests/nodes.cairo | 16 +- .../gather_fp16x16_3d_axis1/input_1.cairo | 18 +- .../gather_fp16x16_3d_axis2/input_1.cairo | 18 +- .../gather_fp16x16_3d_default/input_1.cairo | 18 +- .../gather_fp8x23_3d_axis1/input_1.cairo | 18 +- .../gather_fp8x23_3d_axis2/input_1.cairo | 18 +- .../gather_fp8x23_3d_default/input_1.cairo | 18 +- src/tests/nodes/gather_i32_3d_axis1.cairo | 22 ++ .../nodes/gather_i32_3d_axis1/input_0.cairo | 41 ++++ .../nodes/gather_i32_3d_axis1/input_1.cairo | 18 ++ .../nodes/gather_i32_3d_axis1/output_0.cairo | 69 +++++++ src/tests/nodes/gather_i32_3d_axis2.cairo | 22 ++ .../nodes/gather_i32_3d_axis2/input_0.cairo | 41 ++++ .../nodes/gather_i32_3d_axis2/input_1.cairo | 18 ++ .../nodes/gather_i32_3d_axis2/output_0.cairo | 69 +++++++ src/tests/nodes/gather_i32_3d_default.cairo | 22 ++ .../nodes/gather_i32_3d_default/input_0.cairo | 41 ++++ .../nodes/gather_i32_3d_default/input_1.cairo | 18 ++ .../gather_i32_3d_default/output_0.cairo | 69 +++++++ src/tests/nodes/gather_i8_3d_axis1.cairo | 22 ++ .../nodes/gather_i8_3d_axis1/input_0.cairo | 41 ++++ .../nodes/gather_i8_3d_axis1/input_1.cairo | 18 ++ .../nodes/gather_i8_3d_axis1/output_0.cairo | 69 +++++++ src/tests/nodes/gather_i8_3d_axis2.cairo | 22 ++ .../nodes/gather_i8_3d_axis2/input_0.cairo | 41 ++++ .../nodes/gather_i8_3d_axis2/input_1.cairo | 18 ++ .../nodes/gather_i8_3d_axis2/output_0.cairo | 69 +++++++ src/tests/nodes/gather_i8_3d_default.cairo | 22 ++ .../nodes/gather_i8_3d_default/input_0.cairo | 41 ++++ .../nodes/gather_i8_3d_default/input_1.cairo | 18 ++ .../nodes/gather_i8_3d_default/output_0.cairo | 69 +++++++ src/tests/nodes/gather_u32_3d_axis1.cairo | 22 ++ .../nodes/gather_u32_3d_axis1/input_0.cairo | 49 +++++ .../nodes/gather_u32_3d_axis1/input_1.cairo | 18 ++ .../nodes/gather_u32_3d_axis1/output_0.cairo | 68 ++++++ src/tests/nodes/gather_u32_3d_axis2.cairo | 22 ++ .../nodes/gather_u32_3d_axis2/input_0.cairo | 49 +++++ .../nodes/gather_u32_3d_axis2/input_1.cairo | 18 ++ .../nodes/gather_u32_3d_axis2/output_0.cairo | 86 ++++++++ src/tests/nodes/gather_u32_3d_default.cairo | 22 ++ .../nodes/gather_u32_3d_default/input_0.cairo | 49 +++++ .../nodes/gather_u32_3d_default/input_1.cairo | 18 ++ .../gather_u32_3d_default/output_0.cairo | 86 ++++++++ 53 files changed, 1666 insertions(+), 113 deletions(-) create mode 100644 src/tests/nodes/gather_i32_3d_axis1.cairo create mode 100644 src/tests/nodes/gather_i32_3d_axis1/input_0.cairo create mode 100644 src/tests/nodes/gather_i32_3d_axis1/input_1.cairo create mode 100644 src/tests/nodes/gather_i32_3d_axis1/output_0.cairo create mode 100644 src/tests/nodes/gather_i32_3d_axis2.cairo create mode 100644 src/tests/nodes/gather_i32_3d_axis2/input_0.cairo create mode 100644 src/tests/nodes/gather_i32_3d_axis2/input_1.cairo create mode 100644 src/tests/nodes/gather_i32_3d_axis2/output_0.cairo create mode 100644 src/tests/nodes/gather_i32_3d_default.cairo create mode 100644 src/tests/nodes/gather_i32_3d_default/input_0.cairo create mode 100644 src/tests/nodes/gather_i32_3d_default/input_1.cairo create mode 100644 src/tests/nodes/gather_i32_3d_default/output_0.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis1.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis1/input_0.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis1/input_1.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis1/output_0.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis2.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis2/input_0.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis2/input_1.cairo create mode 100644 src/tests/nodes/gather_i8_3d_axis2/output_0.cairo create mode 100644 src/tests/nodes/gather_i8_3d_default.cairo create mode 100644 src/tests/nodes/gather_i8_3d_default/input_0.cairo create mode 100644 src/tests/nodes/gather_i8_3d_default/input_1.cairo create mode 100644 src/tests/nodes/gather_i8_3d_default/output_0.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis1.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis1/input_0.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis1/input_1.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis1/output_0.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis2.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis2/input_0.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis2/input_1.cairo create mode 100644 src/tests/nodes/gather_u32_3d_axis2/output_0.cairo create mode 100644 src/tests/nodes/gather_u32_3d_default.cairo create mode 100644 src/tests/nodes/gather_u32_3d_default/input_0.cairo create mode 100644 src/tests/nodes/gather_u32_3d_default/input_1.cairo create mode 100644 src/tests/nodes/gather_u32_3d_default/output_0.cairo diff --git a/nodegen/node/gather.py b/nodegen/node/gather.py index c06826f75..689c00943 100644 --- a/nodegen/node/gather.py +++ b/nodegen/node/gather.py @@ -3,18 +3,18 @@ from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl, Trait class Gather(RunAll): + @staticmethod def gather_fp16x16(): def gather_3D(): def default(): x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64) - x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int64) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.uint32) y = x1.take(x2, axis=0) x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) - x2 = Tensor(Dtype.FP16x16, x2.shape, to_fp( - x2.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) y = Tensor(Dtype.FP16x16, y.shape, to_fp( y.flatten(), FixedImpl.FP16x16)) @@ -30,8 +30,7 @@ def axis1(): y = x1.take(x2, axis=1) x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) - x2 = Tensor(Dtype.FP16x16, x2.shape, to_fp( - x2.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) y = Tensor(Dtype.FP16x16, y.shape, to_fp( y.flatten(), FixedImpl.FP16x16)) @@ -47,8 +46,7 @@ def axis2(): y = x1.take(x2, axis=2) x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) - x2 = Tensor(Dtype.FP16x16, x2.shape, to_fp( - x2.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) y = Tensor(Dtype.FP16x16, y.shape, to_fp( y.flatten(), FixedImpl.FP16x16)) @@ -62,7 +60,7 @@ def axis2(): axis1() axis2() gather_3D() - + @staticmethod def gather_fp8x23(): @@ -73,10 +71,8 @@ def default(): y = x1.take(x2, axis=0) x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23)) - x2 = Tensor(Dtype.FP8x23, x2.shape, to_fp( - x2.flatten(), FixedImpl.FP8x23)) - y = Tensor(Dtype.FP8x23, y.shape, to_fp( - y.flatten(), FixedImpl.FP8x23)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23)) name = "gather_fp8x23_3d_default" make_node([x1, x2], [y], name) @@ -90,10 +86,8 @@ def axis1(): y = x1.take(x2, axis=1) x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23)) - x2 = Tensor(Dtype.FP8x23, x2.shape, to_fp( - x2.flatten(), FixedImpl.FP8x23)) - y = Tensor(Dtype.FP8x23, y.shape, to_fp( - y.flatten(), FixedImpl.FP8x23)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23)) name = "gather_fp8x23_3d_axis1" make_node([x1, x2], [y], name) @@ -107,10 +101,8 @@ def axis2(): y = x1.take(x2, axis=2) x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23)) - x2 = Tensor(Dtype.FP8x23, x2.shape, to_fp( - x2.flatten(), FixedImpl.FP8x23)) - y = Tensor(Dtype.FP8x23, y.shape, to_fp( - y.flatten(), FixedImpl.FP8x23)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23)) name = "gather_fp8x23_3d_axis2" make_node([x1, x2], [y], name) @@ -118,6 +110,169 @@ def axis2(): inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))", file_name= name) + default() + axis1() + axis2() + gather_3D() + + @staticmethod + def gather_i8(): + + def gather_3D(): + def default(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int8) + y = x1.take(x2, axis=0) + + x1 = Tensor(Dtype.I8, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "gather_i8_3d_default" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(0))", + file_name= name) + + def axis1(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int8) + y = x1.take(x2, axis=1) + + x1 = Tensor(Dtype.I8, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "gather_i8_3d_axis1" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(1))", + file_name= name) + + def axis2(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int8) + y = x1.take(x2, axis=2) + + x1 = Tensor(Dtype.I8, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "gather_i8_3d_axis2" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))", + file_name= name) + + default() + axis1() + axis2() + gather_3D() + + + @staticmethod + def gather_i32(): + + def gather_3D(): + def default(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int32) + y = x1.take(x2, axis=0) + + x1 = Tensor(Dtype.I32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "gather_i32_3d_default" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(0))", + file_name= name) + + def axis1(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int32) + y = x1.take(x2, axis=1) + + x1 = Tensor(Dtype.I32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "gather_i32_3d_axis1" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(1))", + file_name= name) + + def axis2(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int32) + y = x1.take(x2, axis=2) + + x1 = Tensor(Dtype.I32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "gather_i32_3d_axis2" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))", + file_name= name) + + default() + axis1() + axis2() + gather_3D() + + @staticmethod + def gather_u32(): + + def gather_3D(): + def default(): + x1 = np.arange(0,36).reshape(3,4,3).astype(np.uint32) + x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.uint32) + y = x1.take(x2, axis=0) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "gather_u32_3d_default" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(0))", + file_name= name) + + def axis1(): + x1 = np.arange(0,36).reshape(3,4,3).astype(np.uint32) + x2 = np.array([[0,1], [2,1], [1, 3]]).astype(np.uint32) + y = x1.take(x2, axis=1) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "gather_u32_3d_axis1" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(1))", + file_name= name) + + def axis2(): + x1 = np.arange(0,36).reshape(3,4,3).astype(np.uint32) + x2 = np.array([[0,1], [2,1], [1, 2]]).astype(np.uint32) + y = x1.take(x2, axis=2) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "gather_u32_3d_axis2" + make_node([x1, x2], [y], name) + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))", + file_name= name) + default() axis1() axis2() diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 9316d8afd..934403e73 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -2466,7 +2466,7 @@ trait TensorTrait { /// ``` /// fn gather( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor ; } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index a0553d523..35f34e8ea 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -205,7 +205,7 @@ impl FP16x16Tensor of TensorTrait { } fn gather( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor { math::gather::gather(self, indices, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 7b47b7645..06689dc19 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -206,7 +206,7 @@ impl FP32x32Tensor of TensorTrait { } fn gather( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor { math::gather::gather(self, indices, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 396b7e909..ee63eff60 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -206,7 +206,7 @@ impl FP64x64Tensor of TensorTrait { } fn gather( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor { math::gather::gather(self, indices, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 79ead3b12..bdf0704c1 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -205,7 +205,7 @@ impl FP8x23Tensor of TensorTrait { } fn gather( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor { math::gather::gather(self, indices, axis) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 592941316..e9ed69c3c 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -206,10 +206,9 @@ impl I32Tensor of TensorTrait { } fn gather( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor { - // math::gather::gather(self, indices, axis) - panic(array!['not supported!']) + math::gather::gather(self, indices, axis) } } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index c495be935..dab291408 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -205,10 +205,9 @@ impl I8Tensor of TensorTrait { } fn gather( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor { - // math::gather::gather(self, indices, axis) - panic(array!['not supported!']) + math::gather::gather(self, indices, axis) } } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index a5af503ea..3f22dec21 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -199,10 +199,9 @@ impl U32Tensor of TensorTrait { } fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { - // math::gather::gather(self, indices, axis) - panic(array!['not supported!']) + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather::gather(self, indices, axis) } } diff --git a/src/operators/tensor/math/gather.cairo b/src/operators/tensor/math/gather.cairo index 568c3d761..ae2523abe 100644 --- a/src/operators/tensor/math/gather.cairo +++ b/src/operators/tensor/math/gather.cairo @@ -1,3 +1,4 @@ +use alexandria_data_structures::array_ext::SpanTraitExt; use array::ArrayTrait; use array::SpanTrait; @@ -9,24 +10,16 @@ use core::traits::Destruct; use option::OptionTrait; use orion::numbers::NumberTrait; -use orion::numbers::fixed_point::core::FixedTrait; use orion::operators::tensor::{TensorTrait, Tensor}; /// Cf: TensorTrait::gather docstring fn gather< - T, - MAG, - impl FFixed: FixedTrait, - impl FTensorTrait: TensorTrait, - impl FNumber: NumberTrait, - impl U32TryIntoMAG: TryInto, - impl FPartialEq: PartialEq, - impl FPartialOrd: PartialOrd, - impl FAdd: Add, - impl FCopy: Copy, - impl FDrop: Drop, + T, + impl TTensorTrait: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( - self: @Tensor, indices: Tensor, axis: Option + self: @Tensor, indices: Tensor, axis: Option ) -> Tensor { let data = *self.data; let shape = *self.shape; @@ -41,7 +34,9 @@ fn gather< let rank_indices = indices.shape.len(); let data_indices = indices.data; let axis_shape = *shape.at(axis); - assert(indices.max() < FixedTrait::::new_unscaled(axis_shape.try_into().unwrap(), false), 'this index out of bounds'); + let ind_max = data_indices.max().unwrap(); + assert(ind_max < axis_shape, 'this index out of bounds'); + @@ -115,7 +110,8 @@ fn gather< }; let new_val = inner_loop / divisor % *shape.at(axis); - if (FixedTrait::::new_unscaled(new_val.try_into().unwrap(), false) == indice) { + if (indice == new_val) { + let val = break_loop * outer_loop + inner_loop; let data_val = *data.at(val); output_data.append(data_val); diff --git a/src/tests/nodes.cairo b/src/tests/nodes.cairo index 276bc9fcd..94902a68c 100644 --- a/src/tests/nodes.cairo +++ b/src/tests/nodes.cairo @@ -377,9 +377,19 @@ mod slice_i32_3d; mod slice_i8_2d; mod slice_i8_3d; mod slice_u32_2d; -mod slice_u32_3d; mod gather_fp16x16_3d_default; -mod gather_fp16x16_3d_axis1; -mod gather_fp16x16_3d_axis2; +mod slice_u32_3d; mod gather_fp8x23_3d_default; mod gather_fp8x23_3d_axis1; mod gather_fp8x23_3d_axis2; +mod gather_fp16x16_3d_default; +mod gather_fp16x16_3d_axis1; +mod gather_fp16x16_3d_axis2; +mod gather_i8_3d_default; +mod gather_i8_3d_axis1; +mod gather_i8_3d_axis2; +mod gather_i32_3d_default; +mod gather_i32_3d_axis1; +mod gather_i32_3d_axis2; +mod gather_u32_3d_default; +mod gather_u32_3d_axis1; +mod gather_u32_3d_axis2; diff --git a/src/tests/nodes/gather_fp16x16_3d_axis1/input_1.cairo b/src/tests/nodes/gather_fp16x16_3d_axis1/input_1.cairo index 29b3f7b03..376bc51e6 100644 --- a/src/tests/nodes/gather_fp16x16_3d_axis1/input_1.cairo +++ b/src/tests/nodes/gather_fp16x16_3d_axis1/input_1.cairo @@ -1,20 +1,18 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::FP16x16Tensor; -use orion::numbers::FixedTrait; -use orion::numbers::FP16x16; +use orion::operators::tensor::U32Tensor; -fn input_1() -> Tensor { +fn input_1() -> Tensor { let mut shape = ArrayTrait::::new(); shape.append(3); shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); TensorTrait::new(shape.span(), data.span()) } \ No newline at end of file diff --git a/src/tests/nodes/gather_fp16x16_3d_axis2/input_1.cairo b/src/tests/nodes/gather_fp16x16_3d_axis2/input_1.cairo index 29b3f7b03..376bc51e6 100644 --- a/src/tests/nodes/gather_fp16x16_3d_axis2/input_1.cairo +++ b/src/tests/nodes/gather_fp16x16_3d_axis2/input_1.cairo @@ -1,20 +1,18 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::FP16x16Tensor; -use orion::numbers::FixedTrait; -use orion::numbers::FP16x16; +use orion::operators::tensor::U32Tensor; -fn input_1() -> Tensor { +fn input_1() -> Tensor { let mut shape = ArrayTrait::::new(); shape.append(3); shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); TensorTrait::new(shape.span(), data.span()) } \ No newline at end of file diff --git a/src/tests/nodes/gather_fp16x16_3d_default/input_1.cairo b/src/tests/nodes/gather_fp16x16_3d_default/input_1.cairo index 29b3f7b03..376bc51e6 100644 --- a/src/tests/nodes/gather_fp16x16_3d_default/input_1.cairo +++ b/src/tests/nodes/gather_fp16x16_3d_default/input_1.cairo @@ -1,20 +1,18 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::FP16x16Tensor; -use orion::numbers::FixedTrait; -use orion::numbers::FP16x16; +use orion::operators::tensor::U32Tensor; -fn input_1() -> Tensor { +fn input_1() -> Tensor { let mut shape = ArrayTrait::::new(); shape.append(3); shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); TensorTrait::new(shape.span(), data.span()) } \ No newline at end of file diff --git a/src/tests/nodes/gather_fp8x23_3d_axis1/input_1.cairo b/src/tests/nodes/gather_fp8x23_3d_axis1/input_1.cairo index 964aac912..376bc51e6 100644 --- a/src/tests/nodes/gather_fp8x23_3d_axis1/input_1.cairo +++ b/src/tests/nodes/gather_fp8x23_3d_axis1/input_1.cairo @@ -1,20 +1,18 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::FP8x23Tensor; -use orion::numbers::FixedTrait; -use orion::numbers::FP8x23; +use orion::operators::tensor::U32Tensor; -fn input_1() -> Tensor { +fn input_1() -> Tensor { let mut shape = ArrayTrait::::new(); shape.append(3); shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 0, sign: false }); - data.append(FP8x23 { mag: 8388608, sign: false }); - data.append(FP8x23 { mag: 16777216, sign: false }); - data.append(FP8x23 { mag: 8388608, sign: false }); - data.append(FP8x23 { mag: 0, sign: false }); - data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); TensorTrait::new(shape.span(), data.span()) } \ No newline at end of file diff --git a/src/tests/nodes/gather_fp8x23_3d_axis2/input_1.cairo b/src/tests/nodes/gather_fp8x23_3d_axis2/input_1.cairo index 964aac912..376bc51e6 100644 --- a/src/tests/nodes/gather_fp8x23_3d_axis2/input_1.cairo +++ b/src/tests/nodes/gather_fp8x23_3d_axis2/input_1.cairo @@ -1,20 +1,18 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::FP8x23Tensor; -use orion::numbers::FixedTrait; -use orion::numbers::FP8x23; +use orion::operators::tensor::U32Tensor; -fn input_1() -> Tensor { +fn input_1() -> Tensor { let mut shape = ArrayTrait::::new(); shape.append(3); shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 0, sign: false }); - data.append(FP8x23 { mag: 8388608, sign: false }); - data.append(FP8x23 { mag: 16777216, sign: false }); - data.append(FP8x23 { mag: 8388608, sign: false }); - data.append(FP8x23 { mag: 0, sign: false }); - data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); TensorTrait::new(shape.span(), data.span()) } \ No newline at end of file diff --git a/src/tests/nodes/gather_fp8x23_3d_default/input_1.cairo b/src/tests/nodes/gather_fp8x23_3d_default/input_1.cairo index 964aac912..376bc51e6 100644 --- a/src/tests/nodes/gather_fp8x23_3d_default/input_1.cairo +++ b/src/tests/nodes/gather_fp8x23_3d_default/input_1.cairo @@ -1,20 +1,18 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::FP8x23Tensor; -use orion::numbers::FixedTrait; -use orion::numbers::FP8x23; +use orion::operators::tensor::U32Tensor; -fn input_1() -> Tensor { +fn input_1() -> Tensor { let mut shape = ArrayTrait::::new(); shape.append(3); shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 0, sign: false }); - data.append(FP8x23 { mag: 8388608, sign: false }); - data.append(FP8x23 { mag: 16777216, sign: false }); - data.append(FP8x23 { mag: 8388608, sign: false }); - data.append(FP8x23 { mag: 0, sign: false }); - data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); TensorTrait::new(shape.span(), data.span()) } \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis1.cairo b/src/tests/nodes/gather_i32_3d_axis1.cairo new file mode 100644 index 000000000..a65dc9eca --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis1.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I32Tensor; +use orion::operators::tensor::I32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_i32_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(1)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis1/input_0.cairo b/src/tests/nodes/gather_i32_3d_axis1/input_0.cairo new file mode 100644 index 000000000..aa97ad7e4 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis1/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis1/input_1.cairo b/src/tests/nodes/gather_i32_3d_axis1/input_1.cairo new file mode 100644 index 000000000..376bc51e6 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis1/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis1/output_0.cairo b/src/tests/nodes/gather_i32_3d_axis1/output_0.cairo new file mode 100644 index 000000000..7ce2cdc57 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis1/output_0.cairo @@ -0,0 +1,69 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis2.cairo b/src/tests/nodes/gather_i32_3d_axis2.cairo new file mode 100644 index 000000000..3352e0a93 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis2.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I32Tensor; +use orion::operators::tensor::I32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_i32_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(2)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis2/input_0.cairo b/src/tests/nodes/gather_i32_3d_axis2/input_0.cairo new file mode 100644 index 000000000..aa97ad7e4 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis2/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis2/input_1.cairo b/src/tests/nodes/gather_i32_3d_axis2/input_1.cairo new file mode 100644 index 000000000..376bc51e6 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis2/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_axis2/output_0.cairo b/src/tests/nodes/gather_i32_3d_axis2/output_0.cairo new file mode 100644 index 000000000..4e61ad18d --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_axis2/output_0.cairo @@ -0,0 +1,69 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_default.cairo b/src/tests/nodes/gather_i32_3d_default.cairo new file mode 100644 index 000000000..385310ba7 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_default.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I32Tensor; +use orion::operators::tensor::I32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_i32_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(0)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_default/input_0.cairo b/src/tests/nodes/gather_i32_3d_default/input_0.cairo new file mode 100644 index 000000000..aa97ad7e4 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_default/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_default/input_1.cairo b/src/tests/nodes/gather_i32_3d_default/input_1.cairo new file mode 100644 index 000000000..376bc51e6 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_default/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i32_3d_default/output_0.cairo b/src/tests/nodes/gather_i32_3d_default/output_0.cairo new file mode 100644 index 000000000..38895e3c6 --- /dev/null +++ b/src/tests/nodes/gather_i32_3d_default/output_0.cairo @@ -0,0 +1,69 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis1.cairo b/src/tests/nodes/gather_i8_3d_axis1.cairo new file mode 100644 index 000000000..68971a9c9 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis1.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_i8_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(1)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis1/input_0.cairo b/src/tests/nodes/gather_i8_3d_axis1/input_0.cairo new file mode 100644 index 000000000..7b851ec75 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis1/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis1/input_1.cairo b/src/tests/nodes/gather_i8_3d_axis1/input_1.cairo new file mode 100644 index 000000000..376bc51e6 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis1/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis1/output_0.cairo b/src/tests/nodes/gather_i8_3d_axis1/output_0.cairo new file mode 100644 index 000000000..133276470 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis1/output_0.cairo @@ -0,0 +1,69 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis2.cairo b/src/tests/nodes/gather_i8_3d_axis2.cairo new file mode 100644 index 000000000..5033ea844 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis2.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_i8_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(2)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis2/input_0.cairo b/src/tests/nodes/gather_i8_3d_axis2/input_0.cairo new file mode 100644 index 000000000..7b851ec75 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis2/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis2/input_1.cairo b/src/tests/nodes/gather_i8_3d_axis2/input_1.cairo new file mode 100644 index 000000000..376bc51e6 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis2/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_axis2/output_0.cairo b/src/tests/nodes/gather_i8_3d_axis2/output_0.cairo new file mode 100644 index 000000000..57d540d62 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_axis2/output_0.cairo @@ -0,0 +1,69 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_default.cairo b/src/tests/nodes/gather_i8_3d_default.cairo new file mode 100644 index 000000000..6b71f4977 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_default.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_i8_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(0)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_default/input_0.cairo b/src/tests/nodes/gather_i8_3d_default/input_0.cairo new file mode 100644 index 000000000..7b851ec75 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_default/input_0.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_default/input_1.cairo b/src/tests/nodes/gather_i8_3d_default/input_1.cairo new file mode 100644 index 000000000..376bc51e6 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_default/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_i8_3d_default/output_0.cairo b/src/tests/nodes/gather_i8_3d_default/output_0.cairo new file mode 100644 index 000000000..26722e237 --- /dev/null +++ b/src/tests/nodes/gather_i8_3d_default/output_0.cairo @@ -0,0 +1,69 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis1.cairo b/src/tests/nodes/gather_u32_3d_axis1.cairo new file mode 100644 index 000000000..f91b4307c --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis1.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_u32_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(1)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis1/input_0.cairo b/src/tests/nodes/gather_u32_3d_axis1/input_0.cairo new file mode 100644 index 000000000..62f994c29 --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis1/input_0.cairo @@ -0,0 +1,49 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis1/input_1.cairo b/src/tests/nodes/gather_u32_3d_axis1/input_1.cairo new file mode 100644 index 000000000..a980a959a --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis1/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(1); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis1/output_0.cairo b/src/tests/nodes/gather_u32_3d_axis1/output_0.cairo new file mode 100644 index 000000000..6df19aa35 --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis1/output_0.cairo @@ -0,0 +1,68 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(3); + data.append(4); + data.append(5); + data.append(3); + data.append(4); + data.append(5); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(15); + data.append(16); + data.append(17); + data.append(15); + data.append(16); + data.append(17); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(27); + data.append(28); + data.append(29); + data.append(27); + data.append(28); + data.append(29); + data.append(33); + data.append(34); + data.append(35); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis2.cairo b/src/tests/nodes/gather_u32_3d_axis2.cairo new file mode 100644 index 000000000..77fc85c50 --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis2.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_u32_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(2)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis2/input_0.cairo b/src/tests/nodes/gather_u32_3d_axis2/input_0.cairo new file mode 100644 index 000000000..62f994c29 --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis2/input_0.cairo @@ -0,0 +1,49 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis2/input_1.cairo b/src/tests/nodes/gather_u32_3d_axis2/input_1.cairo new file mode 100644 index 000000000..097a1fbd3 --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis2/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(1); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_axis2/output_0.cairo b/src/tests/nodes/gather_u32_3d_axis2/output_0.cairo new file mode 100644 index 000000000..2f70b0cda --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_axis2/output_0.cairo @@ -0,0 +1,86 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(4); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(7); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(10); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(13); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(16); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(19); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(22); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(25); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(28); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(31); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + data.append(34); + data.append(34); + data.append(35); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_default.cairo b/src/tests/nodes/gather_u32_3d_default.cairo new file mode 100644 index 000000000..3a5a918ed --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_default.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::TensorTrait; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::assert_eq; + +#[test] +#[available_gas(2000000000)] +fn test_gather_u32_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let y = input_0.gather(indices:input_1, axis:Option::Some(0)); + + assert_eq(y, z); +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_default/input_0.cairo b/src/tests/nodes/gather_u32_3d_default/input_0.cairo new file mode 100644 index 000000000..62f994c29 --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_default/input_0.cairo @@ -0,0 +1,49 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_default/input_1.cairo b/src/tests/nodes/gather_u32_3d_default/input_1.cairo new file mode 100644 index 000000000..376bc51e6 --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_default/input_1.cairo @@ -0,0 +1,18 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(1); + data.append(0); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file diff --git a/src/tests/nodes/gather_u32_3d_default/output_0.cairo b/src/tests/nodes/gather_u32_3d_default/output_0.cairo new file mode 100644 index 000000000..1d36ef20c --- /dev/null +++ b/src/tests/nodes/gather_u32_3d_default/output_0.cairo @@ -0,0 +1,86 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + TensorTrait::new(shape.span(), data.span()) +} \ No newline at end of file