Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes setting of writable flag for views and writing to read-only arrays with out keyword #1527

Merged
merged 6 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def _clip_none(x, val, out, order, _binary_fn):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down Expand Up @@ -437,6 +440,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
f"{type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != x.shape:
raise ValueError(
"The shape of input and output arrays are "
Expand Down Expand Up @@ -600,6 +606,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
f"{type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are "
Expand Down
6 changes: 6 additions & 0 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def __call__(self, x, out=None, order="K"):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != x.shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down Expand Up @@ -601,6 +604,9 @@ def __call__(self, o1, o2, out=None, order="K"):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down
3 changes: 3 additions & 0 deletions dpctl/tensor/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down
28 changes: 17 additions & 11 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
view.shape = tuple()
return view

cdef int _copy_writable(int lhs_flags, int rhs_flags):
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)

cdef class usm_ndarray:
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
Expand Down Expand Up @@ -546,7 +549,7 @@ cdef class usm_ndarray:
PyMem_Free(self.shape_)
if (self.strides_):
PyMem_Free(self.strides_)
self.flags_ = contig_flag
self.flags_ = (contig_flag | (self.flags_ & USM_ARRAY_WRITABLE))
self.nd_ = new_nd
self.shape_ = shape_ptr
self.strides_ = strides_ptr
Expand Down Expand Up @@ -725,13 +728,13 @@ cdef class usm_ndarray:
buffer=self.base_,
offset=_meta[2]
)
res.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
res.array_namespace_ = self.array_namespace_

adv_ind = _meta[3]
adv_ind_start_p = _meta[4]

if adv_ind_start_p < 0:
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
Expand All @@ -749,6 +752,7 @@ cdef class usm_ndarray:
if not matching:
raise IndexError("boolean index did not match indexed array in dimensions")
res = _extract_impl(res, key_, axis=adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

if any(ind.dtype == dpt_bool for ind in adv_ind):
Expand All @@ -758,10 +762,13 @@ cdef class usm_ndarray:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
return _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)

return _take_multi_index(res, adv_ind, adv_ind_start_p)
res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

res = _take_multi_index(res, adv_ind, adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

def to_device(self, target, stream=None):
""" to_device(target_device)
Expand Down Expand Up @@ -1040,8 +1047,7 @@ cdef class usm_ndarray:
buffer=self.base_,
offset=_meta[2],
)
# set flags and namespace
Xv.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
# set namespace
Xv.array_namespace_ = self.array_namespace_

from ._copy_utils import (
Expand Down Expand Up @@ -1225,7 +1231,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
offset=offset_elems,
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
r.array_namespace_ = ary.array_namespace_
return r

Expand Down Expand Up @@ -1257,7 +1263,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
offset=offset_elems,
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
r.array_namespace_ = ary.array_namespace_
return r

Expand All @@ -1277,7 +1283,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
offset=ary.get_offset()
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
return r


Expand All @@ -1294,7 +1300,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
offset=ary.get_offset()
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
return r


Expand Down
19 changes: 19 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,25 @@ def test_usm_ndarray_flags_bug_gh_1334():
assert r.flags["F"] and r.flags["C"]


def test_usm_ndarray_writable_flag_views():
get_queue_or_skip()
a = dpt.arange(10, dtype="f4")
a.flags["W"] = False

a.shape = (5, 2)
assert not a.flags.writable
assert not a.T.flags.writable
assert not a.mT.flags.writable
assert not a.real.flags.writable
assert not a[0:3].flags.writable

a = dpt.arange(10, dtype="c8")
a.flags["W"] = False

assert not a.real.flags.writable
assert not a.imag.flags.writable


@pytest.mark.parametrize(
"dtype",
[
Expand Down