diff --git a/dpctl/tensor/_clip.py b/dpctl/tensor/_clip.py index d95c0fa764..9a310df618 100644 --- a/dpctl/tensor/_clip.py +++ b/dpctl/tensor/_clip.py @@ -262,6 +262,9 @@ def _clip_none(x, val, out, order, _binary_fn): f"output array must be of usm_ndarray type, got {type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if out.shape != res_shape: raise ValueError( "The shape of input and output arrays are inconsistent. " @@ -437,6 +440,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"): f"{type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if out.shape != x.shape: raise ValueError( "The shape of input and output arrays are " @@ -600,6 +606,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"): f"{type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if out.shape != res_shape: raise ValueError( "The shape of input and output arrays are " diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 4f0e7bfd37..4a812309a1 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -202,6 +202,9 @@ def __call__(self, x, out=None, order="K"): f"output array must be of usm_ndarray type, got {type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if out.shape != x.shape: raise ValueError( "The shape of input and output arrays are inconsistent. " @@ -601,6 +604,9 @@ def __call__(self, o1, o2, out=None, order="K"): f"output array must be of usm_ndarray type, got {type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if out.shape != res_shape: raise ValueError( "The shape of input and output arrays are inconsistent. " diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 0894ac2077..9f8aef8a48 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -738,6 +738,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): f"output array must be of usm_ndarray type, got {type(out)}" ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if out.shape != res_shape: raise ValueError( "The shape of input and output arrays are inconsistent. " diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 67e144f798..5a8220db0c 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -60,6 +60,9 @@ cdef object _as_zero_dim_ndarray(object usm_ary): view.shape = tuple() return view +cdef int _copy_writable(int lhs_flags, int rhs_flags): + "Copy the WRITABLE flag to lhs_flags from rhs_flags" + return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE) cdef class usm_ndarray: """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \ @@ -546,7 +549,7 @@ cdef class usm_ndarray: PyMem_Free(self.shape_) if (self.strides_): PyMem_Free(self.strides_) - self.flags_ = contig_flag + self.flags_ = (contig_flag | (self.flags_ & USM_ARRAY_WRITABLE)) self.nd_ = new_nd self.shape_ = shape_ptr self.strides_ = strides_ptr @@ -725,13 +728,13 @@ cdef class usm_ndarray: buffer=self.base_, offset=_meta[2] ) - res.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE) res.array_namespace_ = self.array_namespace_ adv_ind = _meta[3] adv_ind_start_p = _meta[4] if adv_ind_start_p < 0: + res.flags_ = _copy_writable(res.flags_, self.flags_) return res from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index @@ -749,6 +752,7 @@ cdef class usm_ndarray: if not matching: raise IndexError("boolean index did not match indexed array in dimensions") res = _extract_impl(res, key_, axis=adv_ind_start_p) + res.flags_ = _copy_writable(res.flags_, self.flags_) return res if any(ind.dtype == dpt_bool for ind in adv_ind): @@ -758,10 +762,13 @@ cdef class usm_ndarray: adv_ind_int.extend(_nonzero_impl(ind)) else: adv_ind_int.append(ind) - return _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p) - - return _take_multi_index(res, adv_ind, adv_ind_start_p) + res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p) + res.flags_ = _copy_writable(res.flags_, self.flags_) + return res + res = _take_multi_index(res, adv_ind, adv_ind_start_p) + res.flags_ = _copy_writable(res.flags_, self.flags_) + return res def to_device(self, target, stream=None): """ to_device(target_device) @@ -1040,8 +1047,7 @@ cdef class usm_ndarray: buffer=self.base_, offset=_meta[2], ) - # set flags and namespace - Xv.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE) + # set namespace Xv.array_namespace_ = self.array_namespace_ from ._copy_utils import ( @@ -1225,7 +1231,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary): offset=offset_elems, order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F') ) - r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE) + r.flags_ = _copy_writable(r.flags_, ary.flags_) r.array_namespace_ = ary.array_namespace_ return r @@ -1257,7 +1263,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary): offset=offset_elems, order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F') ) - r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE) + r.flags_ = _copy_writable(r.flags_, ary.flags_) r.array_namespace_ = ary.array_namespace_ return r @@ -1277,7 +1283,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary): order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'), offset=ary.get_offset() ) - r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE) + r.flags_ = _copy_writable(r.flags_, ary.flags_) return r @@ -1294,7 +1300,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary): order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'), offset=ary.get_offset() ) - r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE) + r.flags_ = _copy_writable(r.flags_, ary.flags_) return r diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 7c5765332b..51f78eb59c 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -129,6 +129,25 @@ def test_usm_ndarray_flags_bug_gh_1334(): assert r.flags["F"] and r.flags["C"] +def test_usm_ndarray_writable_flag_views(): + get_queue_or_skip() + a = dpt.arange(10, dtype="f4") + a.flags["W"] = False + + a.shape = (5, 2) + assert not a.flags.writable + assert not a.T.flags.writable + assert not a.mT.flags.writable + assert not a.real.flags.writable + assert not a[0:3].flags.writable + + a = dpt.arange(10, dtype="c8") + a.flags["W"] = False + + assert not a.real.flags.writable + assert not a.imag.flags.writable + + @pytest.mark.parametrize( "dtype", [