Skip to content

Commit ef16c38

Browse files
committed
address comments
1 parent 1d360ad commit ef16c38

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

dpnp/linalg/dpnp_utils_linalg.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,9 @@ def _norm_int_axis(x, ord, axis, keepdims):
11851185
"""
11861186

11871187
if ord == dpnp.inf:
1188+
if x.shape[axis] == 0:
1189+
x = dpnp.moveaxis(x, axis, -1)
1190+
return dpnp.zeros_like(x, shape=x.shape[:-1])
11881191
return dpnp.abs(x).max(axis=axis, keepdims=keepdims)
11891192
if ord == -dpnp.inf:
11901193
return dpnp.abs(x).min(axis=axis, keepdims=keepdims)
@@ -1220,6 +1223,10 @@ def _norm_tuple_axis(x, ord, row_axis, col_axis, keepdims):
12201223
"""
12211224

12221225
axis = (row_axis, col_axis)
1226+
flag = x.shape[row_axis] == 0 or x.shape[col_axis] == 0
1227+
if flag and ord in [1, 2, dpnp.inf]:
1228+
x = dpnp.moveaxis(x, axis, (-2, -1))
1229+
return dpnp.zeros_like(x, shape=x.shape[:-2])
12231230
if row_axis == col_axis:
12241231
raise ValueError("Duplicate axes given.")
12251232
if ord == 2:
@@ -1251,8 +1258,8 @@ def _norm_tuple_axis(x, ord, row_axis, col_axis, keepdims):
12511258

12521259
if keepdims:
12531260
ret_shape = list(x.shape)
1254-
ret_shape[axis[0]] = 1
1255-
ret_shape[axis[1]] = 1
1261+
ret_shape[row_axis] = 1
1262+
ret_shape[col_axis] = 1
12561263
ret = ret.reshape(ret_shape)
12571264
return ret
12581265

@@ -2401,17 +2408,10 @@ def dpnp_norm(x, ord=None, axis=None, keepdims=False):
24012408
axis = (axis,)
24022409

24032410
if len(axis) == 1:
2404-
if x.shape[axis[0]] == 0 and ord in [1, 2, dpnp.inf]:
2405-
x = dpnp.moveaxis(x, axis, -1)
2406-
return dpnp.zeros_like(x, shape=x.shape[:-1])
24072411
axis = normalize_axis_index(axis[0], ndim)
24082412
return _norm_int_axis(x, ord, axis, keepdims)
24092413

24102414
if len(axis) == 2:
2411-
flag = x.shape[axis[0]] == 0 or x.shape[axis[1]] == 0
2412-
if flag and ord in ["fro", "nuc", 1, 2, dpnp.inf]:
2413-
x = dpnp.moveaxis(x, axis, (-2, -1))
2414-
return dpnp.zeros_like(x, shape=x.shape[:-2])
24152415
row_axis, col_axis = axis
24162416
row_axis = normalize_axis_index(row_axis, ndim)
24172417
col_axis = normalize_axis_index(col_axis, ndim)

dpnp/tests/test_linalg.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -2097,13 +2097,13 @@ def test_empty(self, shape, ord, axis, keepdims):
20972097
assert_raises(ValueError, dpnp.linalg.norm, ia, **kwarg)
20982098
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
20992099
elif axis is None and a.ndim != 1 and a.shape[-1] == 0:
2100-
# TODO: when similar changes in numpy are available,
2101-
# instead of assert_equal with zero, we should compare with numpy
21022100
if ord in [-2, -1, 0, 3]:
21032101
# reduction cannot be performed over zero-size axes
21042102
assert_raises(ValueError, dpnp.linalg.norm, ia, **kwarg)
21052103
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
21062104
else:
2105+
# TODO: when similar changes in numpy are available, instead
2106+
# of assert_equal with zero, we should compare with numpy
21072107
# ord in [None, 1, 2]
21082108
assert_equal(dpnp.linalg.norm(ia, **kwarg), 0)
21092109
else:
@@ -2295,14 +2295,15 @@ def test_matrix_norm(self, ord, keepdims):
22952295

22962296
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.int32])
22972297
@pytest.mark.parametrize(
2298-
"shape_axis", [[(2, 0), None], [(2, 0, 3), (0, 1)]]
2298+
"shape_axis", [[(2, 0), None], [(2, 0), (0, 1)], [(0, 2), (0, 1)]]
22992299
)
23002300
def test_matrix_norm_empty(self, dtype, shape_axis):
23012301
shape, axis = shape_axis[0], shape_axis[1]
23022302
x = dpnp.zeros(shape, dtype=dtype)
23032303

23042304
# TODO: when similar changes in numpy are available,
23052305
# instead of assert_equal with zero, we should compare with numpy
2306+
assert_equal(dpnp.linalg.norm(x, axis=axis), 0)
23062307
assert_equal(dpnp.linalg.norm(x, axis=axis, ord="fro"), 0)
23072308
assert_equal(dpnp.linalg.norm(x, axis=axis, ord="nuc"), 0)
23082309
assert_equal(dpnp.linalg.norm(x, axis=axis, ord=2), 0)
@@ -2315,6 +2316,7 @@ def test_vector_norm_empty(self, dtype, axis):
23152316
x = dpnp.zeros(0, dtype=dtype)
23162317
# TODO: when similar changes in numpy are available,
23172318
# instead of assert_equal with zero, we should compare with numpy
2319+
assert_equal(dpnp.linalg.vector_norm(x, axis=axis), 0)
23182320
assert_equal(dpnp.linalg.vector_norm(x, axis=axis, ord=1), 0)
23192321
assert_equal(dpnp.linalg.vector_norm(x, axis=axis, ord=2), 0)
23202322
assert_equal(dpnp.linalg.vector_norm(x, axis=axis, ord=dpnp.inf), 0)

0 commit comments

Comments
 (0)