Skip to content

Commit

Permalink
Merge pull request #401 from chachaleo/transpose2D
Browse files Browse the repository at this point in the history
addition transpose 2D
  • Loading branch information
raphaelDkhn authored Oct 25, 2023
2 parents 9e9fd37 + 5bd2405 commit eec864d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/operators/tensor/linalg/transpose.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ fn transpose<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: D
assert((*self.shape).len() > 1, 'cannot transpose a 1D tensor');
assert(axes.len() == (*self.shape).len(), 'shape and axes length unequal');

if (*self.shape).len() == 2 {
return transpose2D(@(*self));
}

let output_shape = permutation_output_shape(*self.shape, axes);
let output_data_len = len_from_shape(output_shape);

Expand Down Expand Up @@ -47,3 +51,37 @@ fn transpose<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: D

return TensorTrait::new(output_shape, output_data.span());
}


fn transpose2D<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<T>>(
self: @Tensor<T>
) -> Tensor<T> {
assert((*self.shape).len() == 2, 'transpose a 2D tensor');

let mut output_data = ArrayTrait::new();
let mut output_shape = ArrayTrait::new();

let n = *self.shape[0];
let m = *self.shape[1];

output_shape.append(m);
output_shape.append(n);

let mut j: usize = 0;
loop {
if j == m {
break ();
}
let mut i = 0;
loop {
if i == n {
break ();
}
output_data.append(*(*self.data)[i * m + j]);
i += 1;
};
j += 1;
};

return TensorTrait::new(output_shape.span(), output_data.span());
}
1 change: 1 addition & 0 deletions tests/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ mod tensor_core;
mod nodes;
mod helpers;
mod ml;
mod operators;
1 change: 1 addition & 0 deletions tests/src/operators.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod transpose_test;
41 changes: 41 additions & 0 deletions tests/src/operators/transpose_test.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};
use debug::PrintTrait;


#[test]
#[available_gas(200000000000)]
fn transpose_test_shape() {
let tensor = TensorTrait::<u32>::new(
shape: array![4, 2].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7].span(),
);

let result = tensor.transpose(axes: array![1, 0].span());
assert(result.shape == array![2, 4].span(), 'wrong dim');
}

#[test]
#[available_gas(200000000000)]
fn transpose_test_values() {
let tensor = TensorTrait::<u32>::new(
shape: array![4, 2].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7].span(),
);

let result = tensor.transpose(axes: array![1, 0].span());
assert(result.data == array![0, 2, 4, 6, 1, 3, 5, 7].span(), 'wrong data');
}


#[test]
#[available_gas(200000000000)]
fn transpose_test_3D() {
let tensor = TensorTrait::<u32>::new(
shape: array![2, 2, 2].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7].span(),
);

let result = tensor.transpose(axes: array![1, 2, 0].span());

assert(result.shape == array![2, 2, 2].span(), 'wrong shape');
assert(result.data == array![0, 4, 1, 5, 2, 6, 3, 7].span(), 'wrong data');
}

0 comments on commit eec864d

Please sign in to comment.