Skip to content

Commit

Permalink
ndarray import from buffer protocol requires integer stride. (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
hpkfft authored and wjakob committed Mar 22, 2024
1 parent 15a33f2 commit c30294a
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 5 deletions.
6 changes: 3 additions & 3 deletions docs/api_extra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -634,21 +634,21 @@ section <ndarrays>`.
.. cpp:function:: size_t itemsize() const

Return the size of a single array element in bytes. The returned value
is rounded to the next full byte in case of bit-level representations
is rounded up to the next full byte in case of bit-level representations
(query :cpp:member:`dtype::bits` for bit-level granularity).

.. cpp:function:: size_t nbytes() const

Return the size of the entire array bytes. The returned value is rounded
to the next full byte in case of bit-level representations.
up to the next full byte in case of bit-level representations.

.. cpp:function:: size_t shape(size_t i) const

Return the size of dimension `i`.

.. cpp:function:: int64_t stride(size_t i) const

Return the stride of dimension `i`.
Return the stride (in number of elements) of dimension `i`.

.. cpp:function:: const int64_t* shape_ptr() const

Expand Down
8 changes: 7 additions & 1 deletion src/nb_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,14 @@ static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) {

scoped_pymalloc<int64_t> strides((size_t) view->ndim);
scoped_pymalloc<int64_t> shape((size_t) view->ndim);
const int64_t itemsize = static_cast<int64_t>(view->itemsize);
for (size_t i = 0; i < (size_t) view->ndim; ++i) {
strides[i] = (int64_t) (view->strides[i] / view->itemsize);
int64_t stride = view->strides[i] / itemsize;
if (stride * itemsize != view->strides[i]) {
PyBuffer_Release(view.get());
return nullptr;
}
strides[i] = stride;
shape[i] = (int64_t) view->shape[i];
}

Expand Down
4 changes: 4 additions & 0 deletions tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ NB_MODULE(test_ndarray_ext, m) {
return t.nbytes();
}, "array"_a.noconvert());

m.def("get_stride", [](const nb::ndarray<> &t, size_t i) {
return t.stride(i);
}, "array"_a.noconvert(), "i"_a);

m.def("check_shape_ptr", [](const nb::ndarray<> &t) {
std::vector<int64_t> shape(t.ndim());
std::copy(t.shape_ptr(), t.shape_ptr() + t.ndim(), shape.begin());
Expand Down
27 changes: 26 additions & 1 deletion tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def test28_reference_internal():
assert msg in str(excinfo.value)

@needs_numpy
def test29_force_contig_pytorch():
def test29_force_contig_numpy():
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = t.make_contig(a)
assert b is a
Expand Down Expand Up @@ -656,3 +656,28 @@ def __dlpack__(self):

arr = DLPackWrapper(np.zeros((1)))
assert t.check(arr)

@needs_numpy
def test37_noninteger_stride():
a = np.array([[1, 2, 3, 4, 0, 0], [5, 6, 7, 8, 0, 0]], dtype=np.float32)
s = a[:, 0:4] # slice
t.pass_float32(s)
assert t.get_stride(s, 0) == 6;
assert t.get_stride(s, 1) == 1;
v = s.view(np.complex64)
t.pass_complex64(v)
assert t.get_stride(v, 0) == 3;
assert t.get_stride(v, 1) == 1;

a = np.array([[1, 2, 3, 4, 0], [5, 6, 7, 8, 0]], dtype=np.float32)
s = a[:, 0:4] # slice
t.pass_float32(s)
assert t.get_stride(s, 0) == 5;
assert t.get_stride(s, 1) == 1;
v = s.view(np.complex64)
with pytest.raises(TypeError) as excinfo:
t.pass_complex64(v)
assert 'incompatible function arguments' in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
t.get_stride(v, 0);
assert 'incompatible function arguments' in str(excinfo.value)
2 changes: 2 additions & 0 deletions tests/test_ndarray_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def get_shape(array: Annotated[ArrayLike, dict(writable=False)]) -> list: ...

def get_size(array: ArrayLike) -> int: ...

def get_stride(array: ArrayLike, i: int) -> int: ...

def implicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ...

@overload
Expand Down

0 comments on commit c30294a

Please sign in to comment.