Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow dtype to be passed to blockwise and reduction functions #321

Merged
merged 4 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def mean(x, /, *, axis=None, keepdims=False):
# outputs.
dtype = x.dtype
intermediate_dtype = [("n", np.int64), ("total", np.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype)
return reduction(
x,
_mean_func,
Expand All @@ -44,18 +45,21 @@ def mean(x, /, *, axis=None, keepdims=False):
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
extra_func_kwargs=extra_func_kwargs,
)


def _mean_func(a, **kwargs):
n = _numel(a, **kwargs)
total = np.sum(a, **kwargs)
dtype = dict(kwargs.pop("dtype"))
n = _numel(a, dtype=dtype["n"], **kwargs)
total = np.sum(a, dtype=dtype["total"], **kwargs)
return {"n": n, "total": total}


def _mean_combine(a, **kwargs):
n = np.sum(a["n"], **kwargs)
total = np.sum(a["total"], **kwargs)
dtype = dict(kwargs.pop("dtype"))
n = np.sum(a["n"], dtype=dtype["n"], **kwargs)
total = np.sum(a["total"], dtype=dtype["total"], **kwargs)
return {"n": n, "total": total}


Expand Down Expand Up @@ -114,7 +118,15 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False):
dtype = complex128
else:
dtype = x.dtype
return reduction(x, np.prod, axis=axis, dtype=dtype, keepdims=keepdims)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
np.prod,
axis=axis,
dtype=dtype,
keepdims=keepdims,
extra_func_kwargs=extra_func_kwargs,
)


def sum(x, /, *, axis=None, dtype=None, keepdims=False):
Expand All @@ -131,4 +143,12 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False):
dtype = complex128
else:
dtype = x.dtype
return reduction(x, np.sum, axis=axis, dtype=dtype, keepdims=keepdims)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
np.sum,
axis=axis,
dtype=dtype,
keepdims=keepdims,
extra_func_kwargs=extra_func_kwargs,
)
5 changes: 5 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def blockwise(
new_axes=None,
align_arrays=True,
target_store=None,
extra_func_kwargs=None,
**kwargs,
) -> "Array":
arrays = args[::2]
Expand Down Expand Up @@ -277,6 +278,7 @@ def blockwise(
new_axes=new_axes,
in_names=in_names,
out_name=name,
extra_func_kwargs=extra_func_kwargs,
**kwargs,
)
plan = Plan._new(
Expand Down Expand Up @@ -712,6 +714,7 @@ def reduction(
intermediate_dtype=None,
dtype=None,
keepdims=False,
extra_func_kwargs=None,
) -> "Array":
if combine_func is None:
combine_func = func
Expand Down Expand Up @@ -742,6 +745,7 @@ def reduction(
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

# merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
Expand Down Expand Up @@ -783,6 +787,7 @@ def reduction(
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

if aggegrate_func is not None:
Expand Down
7 changes: 6 additions & 1 deletion cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def blockwise(
in_names: Optional[List[str]] = None,
out_name: Optional[str] = None,
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""Apply a function across blocks from multiple source Zarr arrays.
Expand Down Expand Up @@ -127,6 +128,9 @@ def blockwise(
extra_projected_mem : int
Extra memory projected to be needed (in bytes) in addition to the memory used reading
the input arrays and writing the output.
extra_func_kwargs : dict
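Extra keyword arguments to pass to the function that can't be passed as regular keyword arguments
because they clash with other blockwise arguments (such as dtype). If a key appears in both
``extra_func_kwargs`` and ``**kwargs``, the value from ``extra_func_kwargs`` takes precedence.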
**kwargs : dict
Extra keyword arguments to pass to function

Expand Down Expand Up @@ -197,7 +201,8 @@ def blockwise(
shape, dtype=dtype, chunks=chunksize, store=target_store
)

func_with_kwargs = partial(func, **kwargs)
func_kwargs = extra_func_kwargs or {}
func_with_kwargs = partial(func, **{**kwargs, **func_kwargs})
read_proxies = {
name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items()
}
Expand Down
3 changes: 2 additions & 1 deletion cubed/tests/test_gufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def spec(tmp_path):
@pytest.mark.parametrize("vectorize", [False, True])
def test_apply_reduction(spec, vectorize):
def stats(x):
return np.mean(x, axis=-1)
# note dtype matches output_dtypes in apply_gufunc below
return np.mean(x, axis=-1, dtype=np.float32)

r = np.random.normal(size=(10, 20, 30))
a = cubed.from_array(r, chunks=(5, 5, 30), spec=spec)
Expand Down
Loading