diff --git a/cubed/core/ops.py b/cubed/core/ops.py index a8fa6acb..145ee320 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -564,13 +564,16 @@ def map_blocks( *args: "Array", dtype=None, chunks=None, - drop_axis=[], + drop_axis=None, new_axis=None, spec=None, **kwargs, ) -> "Array": """Apply a function to corresponding blocks from multiple input arrays.""" + if drop_axis is None: + drop_axis = [] + # Handle the case where an array is created by calling `map_blocks` with no input arrays if len(args) == 0: from cubed.array_api.creation_functions import empty_virtual_array @@ -618,10 +621,19 @@ def wrap(*a, **kw): def _map_blocks( - func, *args: "Array", dtype=None, chunks=None, drop_axis=[], new_axis=None, **kwargs + func, + *args: "Array", + dtype=None, + chunks=None, + drop_axis=None, + new_axis=None, + **kwargs, ) -> "Array": # based on dask + if drop_axis is None: + drop_axis = [] + new_axes = {} if isinstance(drop_axis, Number):