Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: test vecdot values, incl complex conj #314

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from . import _array_module as xp
from ._array_module import linalg


def assert_equal(x, y, msg_extra=None):
extra = '' if not msg_extra else f' ({msg_extra})'
if x.dtype in dh.all_float_dtypes:
Expand All @@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None):
else:
assert_exactly_equal(x, y, msg_extra=msg_extra)


def _test_stacks(f, *args, res=None, dims=2, true_val=None,
matrix_axes=(-2, -1),
res_axes=None,
Expand Down Expand Up @@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
if true_val:
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)


def _test_namedtuple(res, fields, func_name):
"""
Test that res is a namedtuple with the correct fields.
Expand All @@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name):
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"


@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(
Expand Down Expand Up @@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0):

_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)


def _conj(x):
"""Work around xp.conj rejecting floats."""
if xp.isdtype(x.dtype, 'complex floating'):
ev-br marked this conversation as resolved.
Show resolved Hide resolved
return xp.conj(x)
else:
return x


def _test_vecdot(namespace, x1, x2, data):
vecdot = namespace.vecdot
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
Expand All @@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data):
ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape],
out_shape=res.shape, expected=expected_shape)

if x1.dtype in dh.int_dtypes:
def true_val(x, y, axis=-1):
return xp.sum(xp.multiply(x, y), dtype=res.dtype)
else:
true_val = None
def true_val(x, y, axis=-1):
return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, if you look here, there is no approximate testing done at all for floating-point values https://github.com/data-apis/array-api-tests/pull/314/files?diff=unified#diff-6056c0b3af9cd3ba66387432a17f5f36bbd54220419656441a8b01bcdc4df44bR57.

We should probably add a flag to that helper to allow approximate testing to be enabled. Some functions are impossible to do approximate testing for because they don't even have a single possible output (e.g., eigh could pick completely different eigenvectors and still be correct).

There are helpers used in the elementwise functions that could be reused here for testing floating-point (and complex) closeness. Basically, they test with very large epsilons. Even that would be enough to detect that a library isn't conjugating, which is the real concern for this test specifically.


_test_stacks(vecdot, x1, x2, res=res, dims=0,
matrix_axes=(axis,), true_val=true_val)
Expand All @@ -944,6 +954,7 @@ def true_val(x, y, axis=-1):
def test_linalg_vecdot(x1, x2, data):
_test_vecdot(linalg, x1, x2, data)


@pytest.mark.unvectorized
@given(
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
Expand All @@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data):
def test_vecdot(x1, x2, data):
_test_vecdot(_array_module, x1, x2, data)


# Insanely large orders might not work. There isn't a limit specified in the
# spec, so we just limit to reasonable values here.
max_ord = 100


@pytest.mark.unvectorized
@pytest.mark.xp_extension('linalg')
@given(
Expand Down
Loading