From 40b8e9585910aa05c6bbd4c56d6226feb1bf26fa Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 21 Sep 2023 13:16:18 +0300 Subject: [PATCH] fix end/start --- src/lib.cairo | 2 +- src/operators/tensor/core.cairo | 28 +++++++++++++++++++--------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/lib.cairo b/src/lib.cairo index 599a350a8..7cc13d360 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -1,5 +1,5 @@ mod operators; mod numbers; -// mod tests; +mod tests; mod utils; diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 3ea6c2241..1453f7eaf 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -2537,7 +2537,11 @@ fn tensor_eq>(mut lhs: Tensor, mut rhs: Tens /// Cf: TensorTrait::slice docstring fn slice, impl TCopy: Copy, impl TDrop: Drop>( - self: @Tensor, starts: Span, ends: Span, axes: Option>, steps: Option> + self: @Tensor, + starts: Span, + ends: Span, + axes: Option>, + steps: Option> ) -> Tensor { let axes = match axes { Option::Some(axes) => axes, @@ -2592,15 +2596,22 @@ fn slice, impl TCopy: Copy, impl TDrop: Drop< let mut processed_params = (0, 0, 0, 0); if is_found { let mut start: usize = *(*self.shape).at(i); - let mut end: usize = 0; + let mut end: usize = *(*self.shape).at(i); + if *starts.at(axis_index) < *(*self.shape).at(i) { start = *starts.at(axis_index); } if *(*self.shape).at(i) > *ends.at(axis_index) { end = *ends.at(axis_index); + } else { + end = *(*self.shape).at(i); + } + + if start > *(*self.shape).at(i) { + start = *(*self.shape).at(i); } - else { + if end > *(*self.shape).at(i) { end = *(*self.shape).at(i); } @@ -2615,10 +2626,10 @@ fn slice, impl TCopy: Copy, impl TDrop: Drop< processed_params = (start, end, *steps.at(axis_index), dim); } } - } else { processed_params = (0, *(*self.shape).at(i), 1, *(*self.shape).at(i)); } + let (start, end, step, shape) = processed_params; processed_starts.append(start); processed_ends.append(end); @@ -2634,7 +2645,7 @@ fn slice, impl TCopy: Copy, impl TDrop: Drop< let mut output_data: Array = ArrayTrait::new(); if is_empty { - return Tensor:: {shape: output_shape.span(), data: output_data.span()}; + return Tensor:: { shape: output_shape.span(), data: output_data.span() }; } let stop_j = (*self.data).len() - 1; @@ -2661,8 +2672,7 @@ fn slice, impl TCopy: Copy, impl TDrop: Drop< } if (index - start) % step == 0 { is_included = true; - } - else { + } else { is_included = false; break (); } @@ -2683,5 +2693,5 @@ fn slice, impl TCopy: Copy, impl TDrop: Drop< j += 1; }; - return TensorTrait::new(output_shape.span(), output_data.span()); -} \ No newline at end of file + return TensorTrait::new(output_shape.span(), output_data.span()); +}