Skip to content

Commit

Permalink
fixed gather i8, i32 , u32
Browse files Browse the repository at this point in the history
  • Loading branch information
hakymulla committed Sep 19, 2023
1 parent 7d65ca0 commit fd53d16
Show file tree
Hide file tree
Showing 53 changed files with 1,666 additions and 113 deletions.
195 changes: 175 additions & 20 deletions nodegen/node/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from ..helpers import make_node, make_test, to_fp, Tensor, Dtype, FixedImpl, Trait

class Gather(RunAll):

@staticmethod
def gather_fp16x16():

def gather_3D():
def default():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int64)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.uint32)
y = x1.take(x2, axis=0)

x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.FP16x16, x2.shape, to_fp(
x2.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

Expand All @@ -30,8 +30,7 @@ def axis1():
y = x1.take(x2, axis=1)

x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.FP16x16, x2.shape, to_fp(
x2.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

Expand All @@ -47,8 +46,7 @@ def axis2():
y = x1.take(x2, axis=2)

x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.FP16x16, x2.shape, to_fp(
x2.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

Expand All @@ -62,7 +60,7 @@ def axis2():
axis1()
axis2()
gather_3D()

@staticmethod
def gather_fp8x23():

Expand All @@ -73,10 +71,8 @@ def default():
y = x1.take(x2, axis=0)

x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.FP8x23, x2.shape, to_fp(
x2.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23))

name = "gather_fp8x23_3d_default"
make_node([x1, x2], [y], name)
Expand All @@ -90,10 +86,8 @@ def axis1():
y = x1.take(x2, axis=1)

x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.FP8x23, x2.shape, to_fp(
x2.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23))

name = "gather_fp8x23_3d_axis1"
make_node([x1, x2], [y], name)
Expand All @@ -107,17 +101,178 @@ def axis2():
y = x1.take(x2, axis=2)

x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.FP8x23, x2.shape, to_fp(
x2.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23))

name = "gather_fp8x23_3d_axis2"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))",
file_name= name)

default()
axis1()
axis2()
gather_3D()

@staticmethod
def gather_i8():

def gather_3D():
def default():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int8)
y = x1.take(x2, axis=0)

x1 = Tensor(Dtype.I8, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())

name = "gather_i8_3d_default"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(0))",
file_name= name)

def axis1():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int8)
y = x1.take(x2, axis=1)

x1 = Tensor(Dtype.I8, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())

name = "gather_i8_3d_axis1"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(1))",
file_name= name)

def axis2():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int8)
y = x1.take(x2, axis=2)

x1 = Tensor(Dtype.I8, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())

name = "gather_i8_3d_axis2"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))",
file_name= name)

default()
axis1()
axis2()
gather_3D()


@staticmethod
def gather_i32():

def gather_3D():
def default():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int32)
y = x1.take(x2, axis=0)

x1 = Tensor(Dtype.I32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())

name = "gather_i32_3d_default"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(0))",
file_name= name)

def axis1():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int32)
y = x1.take(x2, axis=1)

x1 = Tensor(Dtype.I32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())

name = "gather_i32_3d_axis1"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(1))",
file_name= name)

def axis2():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.int32)
y = x1.take(x2, axis=2)

x1 = Tensor(Dtype.I32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())

name = "gather_i32_3d_axis2"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))",
file_name= name)

default()
axis1()
axis2()
gather_3D()

@staticmethod
def gather_u32():

def gather_3D():
def default():
x1 = np.arange(0,36).reshape(3,4,3).astype(np.uint32)
x2 = np.array([[0,1], [2,1], [0, 2]]).astype(np.uint32)
y = x1.take(x2, axis=0)

x1 = Tensor(Dtype.U32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())

name = "gather_u32_3d_default"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(0))",
file_name= name)

def axis1():
x1 = np.arange(0,36).reshape(3,4,3).astype(np.uint32)
x2 = np.array([[0,1], [2,1], [1, 3]]).astype(np.uint32)
y = x1.take(x2, axis=1)

x1 = Tensor(Dtype.U32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())

name = "gather_u32_3d_axis1"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(1))",
file_name= name)

def axis2():
x1 = np.arange(0,36).reshape(3,4,3).astype(np.uint32)
x2 = np.array([[0,1], [2,1], [1, 2]]).astype(np.uint32)
y = x1.take(x2, axis=2)

x1 = Tensor(Dtype.U32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())

name = "gather_u32_3d_axis2"
make_node([x1, x2], [y], name)
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather(indices:input_1, axis:Option::Some(2))",
file_name= name)

default()
axis1()
axis2()
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2466,7 +2466,7 @@ trait TensorTrait<T> {
/// ```
///
fn gather(
self: @Tensor<T>, indices: Tensor<T>, axis: Option<usize>
self: @Tensor<T>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<T> ;
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl FP16x16Tensor of TensorTrait<FP16x16> {
}

fn gather(
self: @Tensor<FP16x16>, indices: Tensor<FP16x16>, axis: Option<usize>
self: @Tensor<FP16x16>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<FP16x16> {
math::gather::gather(self, indices, axis)
}
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ impl FP32x32Tensor of TensorTrait<FP32x32> {
}

fn gather(
self: @Tensor<FP32x32>, indices: Tensor<FP32x32>, axis: Option<usize>
self: @Tensor<FP32x32>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<FP32x32> {
math::gather::gather(self, indices, axis)
}
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp64x64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ impl FP64x64Tensor of TensorTrait<FP64x64> {
}

fn gather(
self: @Tensor<FP64x64>, indices: Tensor<FP64x64>, axis: Option<usize>
self: @Tensor<FP64x64>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<FP64x64> {
math::gather::gather(self, indices, axis)
}
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl FP8x23Tensor of TensorTrait<FP8x23> {
}

fn gather(
self: @Tensor<FP8x23>, indices: Tensor<FP8x23>, axis: Option<usize>
self: @Tensor<FP8x23>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<FP8x23> {
math::gather::gather(self, indices, axis)
}
Expand Down
5 changes: 2 additions & 3 deletions src/operators/tensor/implementations/tensor_i32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,9 @@ impl I32Tensor of TensorTrait<i32> {
}

fn gather(
self: @Tensor<i32>, indices: Tensor<i32>, axis: Option<usize>
self: @Tensor<i32>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<i32> {
// math::gather::gather(self, indices, axis)
panic(array!['not supported!'])
math::gather::gather(self, indices, axis)
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/operators/tensor/implementations/tensor_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,9 @@ impl I8Tensor of TensorTrait<i8> {
}

fn gather(
self: @Tensor<i8>, indices: Tensor<i8>, axis: Option<usize>
self: @Tensor<i8>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<i8> {
// math::gather::gather(self, indices, axis)
panic(array!['not supported!'])
math::gather::gather(self, indices, axis)
}

}
Expand Down
7 changes: 3 additions & 4 deletions src/operators/tensor/implementations/tensor_u32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,9 @@ impl U32Tensor of TensorTrait<u32> {
}

fn gather(
self: @Tensor<u32>, indices: Tensor<u32>, axis: Option<usize>
) -> Tensor<u32> {
// math::gather::gather(self, indices, axis)
panic(array!['not supported!'])
self: @Tensor<u32>, indices: Tensor<usize>, axis: Option<usize>
) -> Tensor<u32> {
math::gather::gather(self, indices, axis)
}
}

Expand Down
Loading

0 comments on commit fd53d16

Please sign in to comment.