Skip to content

Commit

Permalink
Cast inputs to Cubed arrays in apply_ufunc
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Aug 13, 2024
1 parent 0e6e3f0 commit 821e30d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
6 changes: 5 additions & 1 deletion cubed/core/gufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def apply_gufunc(
# Main code:

# Cast all input arrays to cubed
# args = [asarray(a) for a in args] # TODO: do we need to support casting?
from cubed.array_api.creation_functions import asarray

specs = [a.spec for a in args if hasattr(a, "spec")]
spec = specs[0] if len(specs) > 0 else None
args = [asarray(a, spec=spec) for a in args]

if len(input_coredimss) != len(args):
raise ValueError(
Expand Down
10 changes: 10 additions & 0 deletions cubed/tests/test_gufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ def add(x, y):
assert_equal(z, np.array([2, 4, 6]))


def test_apply_gufunc_elemwise_01_non_cubed_input(spec):
def add(x, y):
return x + y

a = cubed.from_array(np.array([1, 2, 3]), chunks=3, spec=spec)
b = np.array([1, 2, 3])
z = apply_gufunc(add, "(),()->()", a, b, output_dtypes=a.dtype)
assert_equal(z, np.array([2, 4, 6]))


def test_apply_gufunc_elemwise_loop(spec):
def foo(x):
assert x.shape in ((2,), (1,))
Expand Down

0 comments on commit 821e30d

Please sign in to comment.