From 98581079ec98c2a30e62d6fcbca8191c71a38e03 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Fri, 9 Aug 2024 22:37:59 +0200 Subject: [PATCH] Fix argmax,argmin dim checks --- phiml/math/_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py index 18173c8..814f4b0 100644 --- a/phiml/math/_ops.py +++ b/phiml/math/_ops.py @@ -1700,7 +1700,7 @@ def argmax(x: Tensor, dim: DimFilter, index_dim=channel('index')): """ dims = x.shape.only(dim) keep = x.shape.without(dims) - assert dim, f"No dimensions {dim} present on key {x.shape}" + assert dims, f"argmax requires dim to be present on data but {dim} does not exist on {x.shape}" if isinstance(x, (SparseCoordinateTensor, CompressedSparseMatrix)): if dims in sparse_dims(x): max_val = max_(x, dim) @@ -1740,7 +1740,7 @@ def argmin(x: Tensor, dim: DimFilter, index_dim=channel('index')): """ dims = x.shape.only(dim) keep = x.shape.without(dims) - assert dim, f"No dimensions {dim} present on key {x.shape}" + assert dims, f"argmin requires dim to be present on data but {dim} does not exist on {x.shape}" if isinstance(x, (SparseCoordinateTensor, CompressedSparseMatrix)): if dims in sparse_dims(x): min_val = min_(x, dim)