Skip to content

Commit

Permalink
improve pyopencl interoperability + added test
Browse files Browse the repository at this point in the history
  • Loading branch information
haesleinhuepf committed May 7, 2022
1 parent 87fd2d1 commit 6d4c65d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
8 changes: 8 additions & 0 deletions pyclesperanto_prototype/_tier0/_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def create_like(*args):
dimensions = dimensions.shape
elif isinstance(dimensions, np.ndarray):
dimensions = dimensions.shape[::-1]
elif hasattr(dimensions, "shape"):
dimensions = dimensions.shape
return create(dimensions)

def create_binary_like(*args):
Expand All @@ -37,6 +39,8 @@ def create_binary_like(*args):
dimensions = dimensions.shape
elif isinstance(dimensions, np.ndarray):
dimensions = dimensions.shape[::-1]
elif hasattr(dimensions, "shape"):
dimensions = dimensions.shape
return create(dimensions, np.uint8)

def create_labels_like(*args):
Expand All @@ -45,6 +49,8 @@ def create_labels_like(*args):
dimensions = dimensions.shape
elif isinstance(dimensions, np.ndarray):
dimensions = dimensions.shape[::-1]
elif hasattr(dimensions, "shape"):
dimensions = dimensions.shape
return create(dimensions, np.uint32)

def create_same_type_like(*args):
Expand All @@ -53,6 +59,8 @@ def create_same_type_like(*args):
dimensions = dimensions.shape
elif isinstance(dimensions, np.ndarray):
dimensions = dimensions.shape[::-1]
elif hasattr(dimensions, "shape"):
dimensions = dimensions.shape
return create(dimensions, dimensions.dtype)

def create_pointlist_from_labelmap(source, *args):
Expand Down
19 changes: 18 additions & 1 deletion tests/test_pyopencl_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,21 @@ def test_pyopencl_compatibility():
cl_c = cl_a + cl_b

import pyclesperanto_prototype as cle
cl_c = cl_a + cl_b
cl_c = cl_a + cl_b

def test_semi_push():
import numpy as np
img = np.asarray([[1,2],[3,4], [5,6]])

import pyclesperanto_prototype as cle
device = cle.get_device()

import pyopencl.array as cla
pushed = cla.to_device(device.queue, img)

print(type(pushed))
print(pushed.shape)
blurred = cle.gaussian_blur(pushed, sigma_x=10, sigma_y=10, sigma_z=10)

assert np.array_equal(blurred.shape, pushed.shape)

0 comments on commit 6d4c65d

Please sign in to comment.