From 6d4c65d72c329ece70df0b58aa742372edb864d4 Mon Sep 17 00:00:00 2001 From: Robert Haase Date: Sat, 7 May 2022 22:41:43 +0200 Subject: [PATCH] improve pyopencl interoperability + added test --- pyclesperanto_prototype/_tier0/_create.py | 8 ++++++++ tests/test_pyopencl_compatibility.py | 19 ++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pyclesperanto_prototype/_tier0/_create.py b/pyclesperanto_prototype/_tier0/_create.py index a3f4dc75..c2295f96 100644 --- a/pyclesperanto_prototype/_tier0/_create.py +++ b/pyclesperanto_prototype/_tier0/_create.py @@ -29,6 +29,8 @@ def create_like(*args): dimensions = dimensions.shape elif isinstance(dimensions, np.ndarray): dimensions = dimensions.shape[::-1] + elif hasattr(dimensions, "shape"): + dimensions = dimensions.shape return create(dimensions) def create_binary_like(*args): @@ -37,6 +39,8 @@ def create_binary_like(*args): dimensions = dimensions.shape elif isinstance(dimensions, np.ndarray): dimensions = dimensions.shape[::-1] + elif hasattr(dimensions, "shape"): + dimensions = dimensions.shape return create(dimensions, np.uint8) def create_labels_like(*args): @@ -45,6 +49,8 @@ def create_labels_like(*args): dimensions = dimensions.shape elif isinstance(dimensions, np.ndarray): dimensions = dimensions.shape[::-1] + elif hasattr(dimensions, "shape"): + dimensions = dimensions.shape return create(dimensions, np.uint32) def create_same_type_like(*args): @@ -53,6 +59,8 @@ def create_same_type_like(*args): dimensions = dimensions.shape elif isinstance(dimensions, np.ndarray): dimensions = dimensions.shape[::-1] + elif hasattr(dimensions, "shape"): + dimensions = dimensions.shape return create(dimensions, dimensions.dtype) def create_pointlist_from_labelmap(source, *args): diff --git a/tests/test_pyopencl_compatibility.py b/tests/test_pyopencl_compatibility.py index 3841e439..9efe0138 100644 --- a/tests/test_pyopencl_compatibility.py +++ b/tests/test_pyopencl_compatibility.py @@ -15,4 +15,21 @@ def test_pyopencl_compatibility(): cl_c = cl_a + cl_b import pyclesperanto_prototype as cle - cl_c = cl_a + cl_b \ No newline at end of file + cl_c = cl_a + cl_b + +def test_semi_push(): + import numpy as np + img = np.asarray([[1,2],[3,4], [5,6]]) + + import pyclesperanto_prototype as cle + device = cle.get_device() + + import pyopencl.array as cla + pushed = cla.to_device(device.queue, img) + + print(type(pushed)) + print(pushed.shape) + blurred = cle.gaussian_blur(pushed, sigma_x=10, sigma_y=10, sigma_z=10) + + assert np.array_equal(blurred.shape, pushed.shape) +