diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index df8fede8..08514717 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -39,7 +39,22 @@ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) -astype = _aliases.astype + +def astype( + x: Array, + dtype: Dtype, + /, + *, + copy: bool = True, + device: Device | None = None +) -> Array: + # TODO: respect device keyword? + if not copy and dtype == x.dtype: + return x + # dask astype doesn't respect copy=True, + # so call copy manually afterwards + x = x.astype(dtype) + return x.copy() if copy else x # Common aliases