Skip to content

Commit 135df75

Browse files
committed
add explicit shape comparison
1 parent f56a43f commit 135df75

8 files changed

+77
-69
lines changed

dpnp/dpnp_iface_linearalgebra.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1275,13 +1275,13 @@ def vdot(a, b):
12751275
if b.size != 1:
12761276
raise ValueError("The second array should be of size one.")
12771277
a_conj = numpy.conj(a)
1278-
return _call_multiply(a_conj, b)
1278+
return dpnp.squeeze(_call_multiply(a_conj, b))
12791279

12801280
if dpnp.isscalar(b):
12811281
if a.size != 1:
12821282
raise ValueError("The first array should be of size one.")
12831283
a_conj = dpnp.conj(a)
1284-
return _call_multiply(a_conj, b)
1284+
return dpnp.squeeze(_call_multiply(a_conj, b))
12851285

12861286
if a.ndim == 1 and b.ndim == 1:
12871287
return dpnp_dot(a, b, out=None, conjugate=True)

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,7 @@ def dpnp_multiplication(
11081108
result = dpnp.moveaxis(result, (-2, -1), axes_res)
11091109
elif len(axes_res) == 1:
11101110
result = dpnp.moveaxis(result, (-1,), axes_res)
1111-
return dpnp.ascontiguousarray(result)
1111+
return result
11121112

11131113
return dpnp.asarray(result, order=order)
11141114

dpnp/tests/helper.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,21 @@
1010
from . import config
1111

1212

13+
def _assert_dtype(a_dt, b_dt, check_only_type_kind=False):
14+
15+
if check_only_type_kind:
16+
assert a_dt.kind == b_dt.kind, f"{a_dt.kind} != {b_dt.kind}"
17+
else:
18+
assert a_dt == b_dt, f"{a_dt} != {b_dt}"
19+
20+
1321
def assert_dtype_allclose(
1422
dpnp_arr,
1523
numpy_arr,
1624
check_type=True,
1725
check_only_type_kind=False,
1826
factor=8,
19-
relative_factor=None,
27+
check_shape=True,
2028
):
2129
"""
2230
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
@@ -40,7 +48,13 @@ def assert_dtype_allclose(
4048
4149
"""
4250

43-
list_64bit_types = [numpy.float64, numpy.complex128]
51+
if check_shape:
52+
if hasattr(numpy_arr, "shape"):
53+
assert dpnp_arr.shape == numpy_arr.shape
54+
else:
55+
# numpy output is scalar, then dpnp is 0-D array
56+
assert dpnp_arr.shape == ()
57+
4458
is_inexact = lambda x: hasattr(x, "dtype") and dpnp.issubdtype(
4559
x.dtype, dpnp.inexact
4660
)
@@ -57,34 +71,32 @@ def assert_dtype_allclose(
5771
else -dpnp.inf
5872
)
5973
tol = factor * max(tol_dpnp, tol_numpy)
60-
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
74+
assert_allclose(dpnp_arr, numpy_arr, atol=tol, rtol=tol)
6175
if check_type:
76+
list_64bit_types = [numpy.float64, numpy.complex128]
6277
numpy_arr_dtype = numpy_arr.dtype
6378
dpnp_arr_dtype = dpnp_arr.dtype
6479
dpnp_arr_dev = dpnp_arr.sycl_device
6580

6681
if check_only_type_kind:
67-
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
82+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype, True)
6883
else:
6984
is_np_arr_f2 = numpy_arr_dtype == numpy.float16
7085

7186
if is_np_arr_f2:
7287
if has_support_aspect16(dpnp_arr_dev):
73-
assert dpnp_arr_dtype == numpy_arr_dtype
88+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype, False)
7489
elif (
7590
numpy_arr_dtype not in list_64bit_types
7691
or has_support_aspect64(dpnp_arr_dev)
7792
):
78-
assert dpnp_arr_dtype == numpy_arr_dtype
93+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype, False)
7994
else:
80-
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
95+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype, True)
8196
else:
82-
assert_array_equal(dpnp_arr.asnumpy(), numpy_arr)
97+
assert_array_equal(dpnp_arr, numpy_arr)
8398
if check_type and hasattr(numpy_arr, "dtype"):
84-
if check_only_type_kind:
85-
assert dpnp_arr.dtype.kind == numpy_arr.dtype.kind
86-
else:
87-
assert dpnp_arr.dtype == numpy_arr.dtype
99+
_assert_dtype(dpnp_arr.dtype, numpy_arr.dtype, check_only_type_kind)
88100

89101

90102
def get_integer_dtypes(all_int_types=False, no_unsigned=False):

dpnp/tests/test_amin_amax.py

+21-23
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@pytest.mark.parametrize("func", ["amax", "amin"])
1111
@pytest.mark.parametrize("keepdims", [True, False])
12-
@pytest.mark.parametrize("dtype", get_all_dtypes())
12+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1313
def test_amax_amin(func, keepdims, dtype):
1414
a = [
1515
[[-2.0, 3.0], [9.1, 0.2]],
@@ -22,52 +22,50 @@ def test_amax_amin(func, keepdims, dtype):
2222
for axis in range(len(a)):
2323
result = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
2424
expected = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
25-
assert_allclose(expected, result)
25+
assert_allclose(expected, result, strict=True)
2626

2727

28-
def _get_min_max_input(type, shape):
28+
def _get_min_max_input(dtype, shape):
2929
size = numpy.prod(shape)
30-
a = numpy.arange(size, dtype=type)
30+
a = numpy.arange(size, dtype=dtype)
3131
a[int(size / 2)] = size + 5
32-
if numpy.issubdtype(type, numpy.unsignedinteger):
32+
if numpy.issubdtype(dtype, numpy.unsignedinteger):
3333
a[int(size / 3)] = size
3434
else:
3535
a[int(size / 3)] = -(size + 5)
3636

3737
return a.reshape(shape)
3838

3939

40-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
40+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
4141
@pytest.mark.parametrize(
42-
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2, 3)", "(4, 5, 6)"]
42+
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["1D", "2D", "3D"]
4343
)
4444
def test_amax_diff_shape(dtype, shape):
4545
a = _get_min_max_input(dtype, shape)
46-
4746
ia = dpnp.array(a)
4847

49-
np_res = numpy.amax(a)
50-
dpnp_res = dpnp.amax(ia)
51-
assert_array_equal(dpnp_res, np_res)
48+
expected = numpy.amax(a)
49+
result = dpnp.amax(ia)
50+
assert_array_equal(result, expected, strict=True)
5251

53-
np_res = a.max()
54-
dpnp_res = ia.max()
55-
numpy.testing.assert_array_equal(dpnp_res, np_res)
52+
expected = a.max()
53+
result = ia.max()
54+
assert_array_equal(result, expected, strict=True)
5655

5756

58-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
57+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
5958
@pytest.mark.parametrize(
60-
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2, 3)", "(4, 5, 6)"]
59+
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["1D", "2D", "3D"]
6160
)
6261
def test_amin_diff_shape(dtype, shape):
6362
a = _get_min_max_input(dtype, shape)
64-
6563
ia = dpnp.array(a)
6664

67-
np_res = numpy.amin(a)
68-
dpnp_res = dpnp.amin(ia)
69-
assert_array_equal(dpnp_res, np_res)
65+
expected = numpy.amin(a)
66+
result = dpnp.amin(ia)
67+
assert_array_equal(result, expected, strict=True)
7068

71-
np_res = a.min()
72-
dpnp_res = ia.min()
73-
assert_array_equal(dpnp_res, np_res)
69+
expected = a.min()
70+
result = ia.min()
71+
assert_array_equal(result, expected, strict=True)

dpnp/tests/test_mathematical.py

+13-29
Original file line numberDiff line numberDiff line change
@@ -2091,14 +2091,12 @@ def test_discont(self, dt):
20912091

20922092

20932093
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
2094-
@pytest.mark.parametrize(
2095-
"val_type", [bool, int, float], ids=["bool", "int", "float"]
2096-
)
2094+
@pytest.mark.parametrize("val_type", [bool, int, float])
20972095
@pytest.mark.parametrize("data_type", get_all_dtypes())
20982096
@pytest.mark.parametrize(
20992097
"func", ["add", "divide", "multiply", "power", "subtract"]
21002098
)
2101-
@pytest.mark.parametrize("val", [0, 1, 5], ids=["0", "1", "5"])
2099+
@pytest.mark.parametrize("val", [0, 1, 5])
21022100
@pytest.mark.parametrize(
21032101
"array",
21042102
[
@@ -2151,7 +2149,7 @@ def test_op_with_scalar(array, val, func, data_type, val_type):
21512149
assert_allclose(result, expected, rtol=1e-6)
21522150

21532151

2154-
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["()", "(3, 2)"])
2152+
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["0D", "2D"])
21552153
@pytest.mark.parametrize("dtype", get_all_dtypes())
21562154
def test_multiply_scalar(shape, dtype):
21572155
np_a = numpy.ones(shape, dtype=dtype)
@@ -2162,7 +2160,7 @@ def test_multiply_scalar(shape, dtype):
21622160
assert_allclose(result, expected)
21632161

21642162

2165-
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["()", "(3, 2)"])
2163+
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["0D", "2D"])
21662164
@pytest.mark.parametrize("dtype", get_all_dtypes())
21672165
def test_add_scalar(shape, dtype):
21682166
np_a = numpy.ones(shape, dtype=dtype)
@@ -2173,7 +2171,7 @@ def test_add_scalar(shape, dtype):
21732171
assert_allclose(result, expected)
21742172

21752173

2176-
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["()", "(3, 2)"])
2174+
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["0D", "2D"])
21772175
@pytest.mark.parametrize("dtype", get_all_dtypes())
21782176
def test_subtract_scalar(shape, dtype):
21792177
np_a = numpy.ones(shape, dtype=dtype)
@@ -2184,7 +2182,7 @@ def test_subtract_scalar(shape, dtype):
21842182
assert_allclose(result, expected)
21852183

21862184

2187-
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["()", "(3, 2)"])
2185+
@pytest.mark.parametrize("shape", [(), (3, 2)], ids=["0D", "2D"])
21882186
@pytest.mark.parametrize("dtype", get_all_dtypes())
21892187
def test_divide_scalar(shape, dtype):
21902188
np_a = numpy.ones(shape, dtype=dtype)
@@ -2196,9 +2194,7 @@ def test_divide_scalar(shape, dtype):
21962194

21972195

21982196
@pytest.mark.parametrize(
2199-
"data",
2200-
[[[1.0, -1.0], [0.1, -0.1]], [-2, -1, 0, 1, 2]],
2201-
ids=["[[1., -1.], [0.1, -0.1]]", "[-2, -1, 0, 1, 2]"],
2197+
"data", [[[1.0, -1.0], [0.1, -0.1]], [-2, -1, 0, 1, 2]], ids=["2D", "1D"]
22022198
)
22032199
@pytest.mark.parametrize(
22042200
"dtype", get_all_dtypes(no_bool=True, no_unsigned=True)
@@ -2231,9 +2227,7 @@ def test_negative_boolean():
22312227

22322228

22332229
@pytest.mark.parametrize(
2234-
"data",
2235-
[[[1.0, -1.0], [0.1, -0.1]], [-2, -1, 0, 1, 2]],
2236-
ids=["[[1., -1.], [0.1, -0.1]]", "[-2, -1, 0, 1, 2]"],
2230+
"data", [[[1.0, -1.0], [0.1, -0.1]], [-2, -1, 0, 1, 2]], ids=["2D", "1D"]
22372231
)
22382232
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
22392233
def test_positive(data, dtype):
@@ -2399,13 +2393,13 @@ def test_projection_infinity(self, dtype):
23992393
a = dpnp.array(X, dtype=dtype)
24002394
result = dpnp.proj(a)
24012395
expected = dpnp.array(Y, dtype=dtype)
2402-
assert_dtype_allclose(result, expected)
2396+
assert_array_equal(result, expected, strict=True)
24032397

24042398
# out keyword
24052399
dp_out = dpnp.empty(expected.shape, dtype=expected.dtype)
24062400
result = dpnp.proj(a, out=dp_out)
24072401
assert dp_out is result
2408-
assert_dtype_allclose(result, expected)
2402+
assert_array_equal(result, expected, strict=True)
24092403

24102404
@pytest.mark.parametrize("dtype", get_all_dtypes())
24112405
def test_projection(self, dtype):
@@ -2793,21 +2787,11 @@ def test_bitwise_1array_input():
27932787

27942788
@pytest.mark.parametrize(
27952789
"x_shape",
2796-
[
2797-
(),
2798-
(2),
2799-
(3, 4),
2800-
(3, 4, 5),
2801-
],
2790+
[(), (2), (3, 4), (3, 4, 5)],
28022791
)
28032792
@pytest.mark.parametrize(
28042793
"y_shape",
2805-
[
2806-
(),
2807-
(2),
2808-
(3, 4),
2809-
(3, 4, 5),
2810-
],
2794+
[(), (2), (3, 4), (3, 4, 5)],
28112795
)
28122796
def test_elemenwise_outer(x_shape, y_shape):
28132797
x_np = numpy.random.random(x_shape)
@@ -2830,4 +2814,4 @@ def test_elemenwise_outer_scalar():
28302814
y = dpnp.asarray(s)
28312815
expected = dpnp.add.outer(x, y)
28322816
result = dpnp.add.outer(x, s)
2833-
assert_dtype_allclose(result, expected)
2817+
assert_array_equal(result, expected, strict=True)

dpnp/tests/test_nanfunctions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def test_allnans(self, dtype, array):
125125

126126
result = getattr(dpnp, self.func)(ia)
127127
expected = getattr(numpy, self.func)(a)
128-
assert_dtype_allclose(result, expected)
128+
# for "0d" case, dpnp returns 0D array, numpy returns 1D array
129+
# Array API indicates that the behavior is unspecified
130+
assert_dtype_allclose(result, expected, check_shape=False)
129131

130132
@pytest.mark.parametrize("axis", [None, 0, 1])
131133
def test_empty(self, axis):

dpnp/tests/test_product.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -902,10 +902,13 @@ def test_strided1(self, dtype, stride):
902902
expected = numpy.matmul(a, a)
903903
assert_dtype_allclose(result, expected, factor=16)
904904

905-
iOUT = dpnp.empty(shape, dtype=result.dtype)
905+
OUT = numpy.empty(shape, dtype=result.dtype)
906+
out = OUT[slices]
907+
iOUT = dpnp.array(OUT)
906908
iout = iOUT[slices]
907909
result = dpnp.matmul(ia, ia, out=iout)
908910
assert result is iout
911+
expected = numpy.matmul(a, a, out=out)
909912
assert_dtype_allclose(result, expected, factor=16)
910913

911914
@pytest.mark.parametrize("dtype", _selected_dtypes)

dpnp/tests/testing/array.py

+9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ def _assert(assert_func, result, expected, *args, **kwargs):
3939
result = convert_item(result)
4040
expected = convert_item(expected)
4141

42+
# original versions of assert_equal, assert_array_equal, and assert_allclose
43+
# (since NumPy 2.0) have `strict` parameter. Added here for
44+
# assert_almost_equal, assert_array_almost_equal, and assert_allclose
45+
# (NumPy < 2.0)
46+
if kwargs.get("strict"):
47+
assert result.dtype == expected.dtype
48+
assert result.shape == expected.shape
49+
kwargs.pop("strict")
50+
4251
assert_func(result, expected, *args, **kwargs)
4352

4453

0 commit comments

Comments
 (0)