Skip to content

Commit

Permalink
Cast inputs to Cubed arrays in apply_ufunc (#551)
Browse files Browse the repository at this point in the history
* Cast inputs to Cubed arrays in `apply_ufunc`

* Add comment about specs being the same
  • Loading branch information
tomwhite authored Aug 14, 2024
1 parent 0e6e3f0 commit 19edd81
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
8 changes: 7 additions & 1 deletion cubed/core/gufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@ def apply_gufunc(
# Main code:

# Cast all input arrays to cubed
# args = [asarray(a) for a in args] # TODO: do we need to support casting?
# Use a spec if there is one. Note that all args have to have the same spec, and
# this will be checked later when constructing the plan (see check_array_specs).
from cubed.array_api.creation_functions import asarray

specs = [a.spec for a in args if hasattr(a, "spec")]
spec = specs[0] if len(specs) > 0 else None
args = [asarray(a, spec=spec) for a in args]

if len(input_coredimss) != len(args):
raise ValueError(
Expand Down
10 changes: 10 additions & 0 deletions cubed/tests/test_gufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ def add(x, y):
assert_equal(z, np.array([2, 4, 6]))


def test_apply_gufunc_elemwise_01_non_cubed_input(spec):
def add(x, y):
return x + y

a = cubed.from_array(np.array([1, 2, 3]), chunks=3, spec=spec)
b = np.array([1, 2, 3])
z = apply_gufunc(add, "(),()->()", a, b, output_dtypes=a.dtype)
assert_equal(z, np.array([2, 4, 6]))


def test_apply_gufunc_elemwise_loop(spec):
def foo(x):
assert x.shape in ((2,), (1,))
Expand Down

0 comments on commit 19edd81

Please sign in to comment.