From fe3e05c2cd81b97c4fee37df490eda09a605b407 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 20 Nov 2024 14:24:44 +0200 Subject: [PATCH 1/2] ENH: test vecdot values, incl complex conj --- array_api_tests/test_linalg.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 40fe035d..10344ec1 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -45,6 +45,7 @@ from . import _array_module as xp from ._array_module import linalg + def assert_equal(x, y, msg_extra=None): extra = '' if not msg_extra else f' ({msg_extra})' if x.dtype in dh.all_float_dtypes: @@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None): else: assert_exactly_equal(x, y, msg_extra=msg_extra) + def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), res_axes=None, @@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, if true_val: assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra) + def _test_namedtuple(res, fields, func_name): """ Test that res is a namedtuple with the correct fields. @@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name): assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field" assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}" + @pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( @@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0): _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace) + +def _conj(x): + """Work around xp.conj rejecting floats.""" + if xp.isdtype(x.dtype, 'complex floating'): + return xp.conj(x) + else: + return x + + def _test_vecdot(namespace, x1, x2, data): vecdot = namespace.vecdot broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape) @@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data): ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=expected_shape) - if x1.dtype in dh.int_dtypes: - def true_val(x, y, axis=-1): - return xp.sum(xp.multiply(x, y), dtype=res.dtype) - else: - true_val = None + def true_val(x, y, axis=-1): + return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype) _test_stacks(vecdot, x1, x2, res=res, dims=0, matrix_axes=(axis,), true_val=true_val) @@ -944,6 +954,7 @@ def true_val(x, y, axis=-1): def test_linalg_vecdot(x1, x2, data): _test_vecdot(linalg, x1, x2, data) + @pytest.mark.unvectorized @given( *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)), @@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data): def test_vecdot(x1, x2, data): _test_vecdot(_array_module, x1, x2, data) + # Insanely large orders might not work. There isn't a limit specified in the # spec, so we just limit to reasonable values here. max_ord = 100 + @pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( From 6ea8ae29e7e5516d4f1184292069152e2b3fef70 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 20 Nov 2024 22:49:54 +0200 Subject: [PATCH 2/2] MAINT: inline the check for complex dtypes Preferred as long as xp.isdtype is not universally available (looking at you pytorch) --- array_api_tests/test_linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 10344ec1..c997948b 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -907,8 +907,8 @@ def true_trace(x_stack, offset=0): def _conj(x): - """Work around xp.conj rejecting floats.""" - if xp.isdtype(x.dtype, 'complex floating'): + # XXX: replace with xp.dtype when all array libraries implement it + if x.dtype in (xp.complex64, xp.complex128): return xp.conj(x) else: return x