Skip to content

Commit

Permalink
fix end/start
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Sep 21, 2023
1 parent 47db631 commit 40b8e95
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod operators;
mod numbers;
// mod tests;
mod tests;
mod utils;

28 changes: 19 additions & 9 deletions src/operators/tensor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2537,7 +2537,11 @@ fn tensor_eq<T, impl TPartialEq: PartialEq<T>>(mut lhs: Tensor<T>, mut rhs: Tens

/// Cf: TensorTrait::slice docstring
fn slice<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<T>>(
self: @Tensor<T>, starts: Span<usize>, ends: Span<usize>, axes: Option<Span<usize>>, steps: Option<Span<usize>>
self: @Tensor<T>,
starts: Span<usize>,
ends: Span<usize>,
axes: Option<Span<usize>>,
steps: Option<Span<usize>>
) -> Tensor<T> {
let axes = match axes {
Option::Some(axes) => axes,
Expand Down Expand Up @@ -2592,15 +2596,22 @@ fn slice<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<
let mut processed_params = (0, 0, 0, 0);
if is_found {
let mut start: usize = *(*self.shape).at(i);
let mut end: usize = 0;
let mut end: usize = *(*self.shape).at(i);

if *starts.at(axis_index) < *(*self.shape).at(i) {
start = *starts.at(axis_index);
}

if *(*self.shape).at(i) > *ends.at(axis_index) {
end = *ends.at(axis_index);
} else {
end = *(*self.shape).at(i);
}

if start > *(*self.shape).at(i) {
start = *(*self.shape).at(i);
}
else {
if end > *(*self.shape).at(i) {
end = *(*self.shape).at(i);
}

Expand All @@ -2615,10 +2626,10 @@ fn slice<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<
processed_params = (start, end, *steps.at(axis_index), dim);
}
}

} else {
processed_params = (0, *(*self.shape).at(i), 1, *(*self.shape).at(i));
}

let (start, end, step, shape) = processed_params;
processed_starts.append(start);
processed_ends.append(end);
Expand All @@ -2634,7 +2645,7 @@ fn slice<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<
let mut output_data: Array<T> = ArrayTrait::new();

if is_empty {
return Tensor::<T> {shape: output_shape.span(), data: output_data.span()};
return Tensor::<T> { shape: output_shape.span(), data: output_data.span() };
}

let stop_j = (*self.data).len() - 1;
Expand All @@ -2661,8 +2672,7 @@ fn slice<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<
}
if (index - start) % step == 0 {
is_included = true;
}
else {
} else {
is_included = false;
break ();
}
Expand All @@ -2683,5 +2693,5 @@ fn slice<T, impl TTensor: TensorTrait<T>, impl TCopy: Copy<T>, impl TDrop: Drop<
j += 1;
};

return TensorTrait::new(output_shape.span(), output_data.span());
}
return TensorTrait::new(output_shape.span(), output_data.span());
}

0 comments on commit 40b8e95

Please sign in to comment.