From 32f347f5a850272f4cb4ee3fe08565f5abe8dcf5 Mon Sep 17 00:00:00 2001 From: bokutotu Date: Tue, 24 Sep 2024 04:48:27 +0900 Subject: [PATCH] update --- zenu-matrix/src/operation/max.rs | 3 +++ zenu-test/src/lib.rs | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/zenu-matrix/src/operation/max.rs b/zenu-matrix/src/operation/max.rs index 27f6b12d..48ab12e1 100644 --- a/zenu-matrix/src/operation/max.rs +++ b/zenu-matrix/src/operation/max.rs @@ -43,6 +43,9 @@ impl MaxIdx for Nvidia { impl, D: Device> Matrix { #[must_use] pub fn max_idx(&self) -> DimDyn { + if self.shape().is_empty() { + return DimDyn::from(&[] as &[usize]); + } let default_stride = self.to_default_stride(); let idx = ::max_idx( default_stride.as_ptr(), diff --git a/zenu-test/src/lib.rs b/zenu-test/src/lib.rs index 6f4dec07..229dbf38 100644 --- a/zenu-test/src/lib.rs +++ b/zenu-test/src/lib.rs @@ -8,15 +8,15 @@ macro_rules! assert_mat_eq_epsilon { let epsilon = $epsilon; let diff = mat.to_ref() - mat2.to_ref(); let abs = diff.abs(); - let diff_asum = abs.asum(); - if diff_asum > epsilon { + let diff_max = abs.max_item(); + if diff_max > epsilon { panic!( "assertion failed: `(left == right)`\n\ left: \n{:?},\n\ right: \n{:?}\n\ diff: \n{:?}\n\ - diff_asum: \n{:?}", - mat, mat2, diff, diff_asum + diff_max: \n{:?}", + mat, mat2, diff, diff_max ); } }};