From 821e30d47df99726ae8e8ab11a2af37cef89b924 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 13 Aug 2024 18:01:10 +0100 Subject: [PATCH] Cast inputs to Cubed arrays in `apply_ufunc` --- cubed/core/gufunc.py | 6 +++++- cubed/tests/test_gufunc.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/cubed/core/gufunc.py b/cubed/core/gufunc.py index bc85e8e8..776ce82f 100644 --- a/cubed/core/gufunc.py +++ b/cubed/core/gufunc.py @@ -68,7 +68,11 @@ def apply_gufunc( # Main code: # Cast all input arrays to cubed - # args = [asarray(a) for a in args] # TODO: do we need to support casting? + from cubed.array_api.creation_functions import asarray + + specs = [a.spec for a in args if hasattr(a, "spec")] + spec = specs[0] if len(specs) > 0 else None + args = [asarray(a, spec=spec) for a in args] if len(input_coredimss) != len(args): raise ValueError( diff --git a/cubed/tests/test_gufunc.py b/cubed/tests/test_gufunc.py index a06d2e65..deb7d583 100644 --- a/cubed/tests/test_gufunc.py +++ b/cubed/tests/test_gufunc.py @@ -38,6 +38,16 @@ def add(x, y): assert_equal(z, np.array([2, 4, 6])) +def test_apply_gufunc_elemwise_01_non_cubed_input(spec): + def add(x, y): + return x + y + + a = cubed.from_array(np.array([1, 2, 3]), chunks=3, spec=spec) + b = np.array([1, 2, 3]) + z = apply_gufunc(add, "(),()->()", a, b, output_dtypes=a.dtype) + assert_equal(z, np.array([2, 4, 6])) + + def test_apply_gufunc_elemwise_loop(spec): def foo(x): assert x.shape in ((2,), (1,))