Skip to content

Commit

Permalink
Fix flaky RuntimeWarning during array reductions (dask#10030)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Mar 8, 2023
1 parent c36fe08 commit 5f1fc42
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions dask/array/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def prod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
if dtype is not None:
dt = dtype
else:
dt = getattr(np.empty((1,), dtype=a.dtype).prod(), "dtype", object)
dt = getattr(np.ones((1,), dtype=a.dtype).prod(), "dtype", object)
return reduction(
a,
chunk.prod,
Expand Down Expand Up @@ -504,7 +504,7 @@ def nansum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None)
if dtype is not None:
dt = dtype
else:
dt = getattr(chunk.nansum(np.empty((1,), dtype=a.dtype)), "dtype", object)
dt = getattr(chunk.nansum(np.ones((1,), dtype=a.dtype)), "dtype", object)
return reduction(
a,
chunk.nansum,
Expand All @@ -522,7 +522,7 @@ def nanprod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None
if dtype is not None:
dt = dtype
else:
dt = getattr(chunk.nansum(np.empty((1,), dtype=a.dtype)), "dtype", object)
dt = getattr(chunk.nansum(np.ones((1,), dtype=a.dtype)), "dtype", object)
return reduction(
a,
chunk.nanprod,
Expand Down Expand Up @@ -731,7 +731,7 @@ def nanmean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None
if dtype is not None:
dt = dtype
else:
dt = getattr(np.mean(np.empty(shape=(1,), dtype=a.dtype)), "dtype", object)
dt = getattr(np.mean(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object)
return reduction(
a,
partial(mean_chunk, sum=chunk.nansum, numel=nannumel),
Expand Down Expand Up @@ -1352,7 +1352,7 @@ def prefixscan_blelloch(func, preop, binop, x, axis=None, dtype=None, out=None):
x = x.flatten().rechunk(chunks=x.npartitions)
axis = 0
if dtype is None:
dtype = getattr(func(np.empty((0,), dtype=x.dtype)), "dtype", object)
dtype = getattr(func(np.ones((0,), dtype=x.dtype)), "dtype", object)
assert isinstance(axis, Integral)
axis = validate_axis(axis, x.ndim)
name = f"{func.__name__}-{tokenize(func, axis, preop, binop, x, dtype)}"
Expand Down Expand Up @@ -1499,7 +1499,7 @@ def cumreduction(
x = x.flatten().rechunk(chunks=x.npartitions)
axis = 0
if dtype is None:
dtype = getattr(func(np.empty((0,), dtype=x.dtype)), "dtype", object)
dtype = getattr(func(np.ones((0,), dtype=x.dtype)), "dtype", object)
assert isinstance(axis, Integral)
axis = validate_axis(axis, x.ndim)

Expand Down

0 comments on commit 5f1fc42

Please sign in to comment.