diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e5e77c677681..0cf99f815e90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -98,6 +98,7 @@ repos: "-sn", # Don't display the score "--disable=import-error", "--disable=redefined-builtin", - "--disable=unused-wildcard-import" + "--disable=unused-wildcard-import", + "--class-naming-style=snake_case" ] files: '^dpnp/(dpnp_iface.*|fft|linalg)' diff --git a/doc/conf.py b/doc/conf.py index 081070a5e59b..0a7dd57a2f33 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -8,7 +8,13 @@ from sphinx.ext.autodoc import FunctionDocumenter -from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc +from dpnp.dpnp_algo.dpnp_elementwise_common import ( + DPNPBinaryFunc, + DPNPUnaryFunc, + binary_ufunc, + ufunc, + unary_ufunc, +) try: import comparison_generator @@ -202,7 +208,10 @@ # -- Options for todo extension ---------------------------------------------- def _can_document_member(member, *args, **kwargs): - if isinstance(member, (DPNPBinaryFunc, DPNPUnaryFunc)): + if isinstance( + member, + (DPNPBinaryFunc, DPNPUnaryFunc, ufunc, unary_ufunc, binary_ufunc), + ): return True return orig(member, *args, **kwargs) diff --git a/doc/reference/ufunc.rst b/doc/reference/ufunc.rst index 2dffca15e889..2acb004f03fc 100644 --- a/doc/reference/ufunc.rst +++ b/doc/reference/ufunc.rst @@ -6,6 +6,50 @@ Universal Functions (ufunc) .. https://docs.scipy.org/doc/numpy/reference/ufuncs.html DPNP provides universal functions (a.k.a. ufuncs) to support various element-wise operations. +DPNP ufunc supports following features of NumPy’s one: + +- Broadcasting +- Output type determination +- Casting rules + +ufuncs +------ +.. autosummary:: + :toctree: generated/ + + dpnp.ufunc + +Attributes +~~~~~~~~~~ + +There are some informational attributes that universal functions +possess. None of the attributes can be set. + +============ ================================================================= +**__doc__** A docstring for each ufunc. The first part of the docstring is + dynamically generated from the number of outputs, the name, and + the number of inputs. The second part of the docstring is + provided at creation time and stored with the ufunc. + +**__name__** The name of the ufunc. +============ ================================================================= + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + dpnp.ufunc.nin + dpnp.ufunc.nout + dpnp.ufunc.nargs + dpnp.ufunc.types + dpnp.ufunc.ntypes + +Methods +~~~~~~~ +.. autosummary:: + :toctree: generated/ + + dpnp.ufunc.outer Available ufuncs ---------------- diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 374981a63031..121cc9925fb4 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -31,6 +31,7 @@ ) import dpnp +import dpnp.dpnp_algo.dpnp_elementwise_docs as ufunc_docs from dpnp.dpnp_array import dpnp_array __all__ = [ @@ -43,9 +44,459 @@ "DPNPReal", "DPNPRound", "DPNPUnaryFunc", + "ufunc", + "unary_ufunc", + "binary_ufunc", ] +import dpctl.tensor._tensor_elementwise_impl as ti + +import dpnp.backend.extensions.vm._vm_impl as vmi + + +class ufunc: + """ + ufunc() + + Functions that operate element by element on whole arrays. + + Calling ufuncs + -------------- + op(*x[, out], **kwargs) + + Apply `op` to the arguments `*x` elementwise, broadcasting the arguments. + + Parameters + ---------- + *x : {dpnp.ndarray, usm_ndarray} + Input arrays. + out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. + Array must have the correct shape and the expected data type. + order : {None, "C", "F", "A", "K"}, optional + Memory layout of the newly output array, Cannot be provided + together with `out`. Default: ``"K"``. + dtype : {None, dtype}, optional + If provided, the destination array will have this dtype. Cannot be + provided together with `out`. Default: ``None``. + casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional + Controls what kind of data casting may occur. Cannot be provided + together with `out`. Default: ``"safe"``. + + Limitations + ----------- + Keyword arguments `where` and `subok` are supported with their default values. + Other keyword arguments is currently unsupported. + Otherwise ``NotImplementedError`` exception will be raised. + + """ + + def __init__( + self, + name, + nin, + nout=1, + func=None, + to_usm_astype=None, + ): + self.nin_ = nin + self.nout_ = nout + self.func = func + self.to_usm_astype = to_usm_astype + self.__name__ = name + self.__doc__ = getattr(ufunc_docs, name + "_docstring") + + def __call__( + self, + *args, + out=None, + where=True, + casting="same_kind", + order="K", + dtype=None, + subok=True, + **kwargs, + ): + dpnp.check_supported_arrays_type( + *args, scalar_type=True, all_scalars=False + ) + if kwargs: + raise NotImplementedError( + f"Requested function={self.__name__} with kwargs={kwargs} " + "isn't currently supported." + ) + if where is not True: + raise NotImplementedError( + f"Requested function={self.__name__} with where={where} " + "isn't currently supported." + ) + if subok is not True: + raise NotImplementedError( + f"Requested function={self.__name__} with subok={subok} " + "isn't currently supported." + ) + if (dtype is not None or casting != "same_kind") and out is not None: + raise TypeError( + f"Requested function={self.__name__} only takes `out` or " + "`dtype` as an argument, but both were provided." + ) + if order is None: + order = "K" + elif order in "afkcAFKC": + order = order.upper() + else: + raise ValueError( + f"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')" + ) + + astype_usm_args = self.to_usm_astype(*args, dtype, casting) + + out_usm = None if out is None else dpnp.get_usm_ndarray(out) + + res_usm = self.func.__call__(*astype_usm_args, out=out_usm, order=order) + + if out is not None and isinstance(out, dpnp_array): + return out + return dpnp_array._create_from_usm_ndarray(res_usm) + + @property + def nin(self): + """ + Returns the number of arguments treated as inputs. + + Examples + -------- + >>> import dpnp as np + >>> np.add.nin + 2 + >>> np.multiply.nin + 2 + >>> np.power.nin + 2 + >>> np.exp.nin + 1 + + """ + + return self.nin_ + + @property + def nout(self): + """ + Returns the number of arguments treated as outputs. + + Examples + -------- + >>> import dpnp as np + >>> np.add.nin + 1 + >>> np.multiply.nin + 1 + >>> np.power.nin + 1 + >>> np.exp.nin + 1 + + """ + + return self.nout_ + + @property + def nargs(self): + """ + Returns the number of arguments treated. + + Examples + -------- + >>> import dpnp as np + >>> np.add.nin + 3 + >>> np.multiply.nin + 3 + >>> np.power.nin + 3 + >>> np.exp.nin + 2 + + """ + + return self.nin_ + self.nout_ + + @property + def types(self): + """ + Returns information about types supported by implementation function, + using NumPy's character encoding for data types, e.g. + + Examples + -------- + >>> import dpnp as np + >>> np.add.types + ['??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I', + 'll->l', 'LL->L', 'ee->e', 'ff->f', 'dd->d', 'FF->F', 'DD->D'] + + >>> np.multiply.types + ['??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I', + 'll->l', 'LL->L', 'ee->e', 'ff->f', 'dd->d', 'FF->F', 'DD->D'] + + >>> np.power.types + ['bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I', 'll->l', + 'LL->L', 'ee->e', 'ff->f', 'dd->d', 'FF->F', 'DD->D'] + + >>> np.exp.types + ['e->e', 'f->f', 'd->d', 'F->F', 'D->D'] + + >>> np.remainder.types + ['bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I', 'll->l', + 'LL->L', 'ee->e', 'ff->f', 'dd->d'] + + """ + + return self.func.types + + @property + def ntypes(self): + """ + The number of types. + + Examples + -------- + >>> import dpnp as np + >>> np.add.ntypes + 14 + >>> np.multiply.ntypes + 14 + >>> np.power.ntypes + 13 + >>> np.exp.ntypes + 5 + >>> np.remainder.ntypes + 11 + + """ + + return len(self.func.types) + + def outer( + self, + x1, + x2, + out=None, + where=True, + order="K", + dtype=None, + subok=True, + **kwargs, + ): + """ + Apply the ufunc op to all pairs (a, b) with a in A and b in B. + + Parameters + ---------- + x1 : {dpnp.ndarray, usm_ndarray} + First input array. + x2 : {dpnp.ndarray, usm_ndarray} + Second input array. + out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. + Array must have the correct shape and the expected data type. + **kwargs + For other keyword-only arguments, see the :obj:`dpnp.ufunc`. + + Returns + ------- + out : dpnp.ndarray + Output array. The data type of the returned array is determined by + the Type Promotion Rules. + + Limitations + ----------- + Parameters `where` and `subok` are supported with their default values. + Keyword argument `kwargs` is currently unsupported. + Otherwise ``NotImplementedError`` exception will be raised. + + See also + -------- + :obj:`dpnp.outer` : A less powerful version of dpnp.multiply.outer + that ravels all inputs to 1D. This exists primarily + for compatibility with old code. + + :obj:`dpnp.tensordot` : dpnp.tensordot(a, b, axes=((), ())) and + dpnp.multiply.outer(a, b) behave same for all + dimensions of a and b. + + Examples + -------- + >>> import dpnp as np + >>> A = np.array([1, 2, 3]) + >>> B = np.array([4, 5, 6]) + >>> np.multiply.outer(A, B) + array([[ 4, 5, 6], + [ 8, 10, 12], + [12, 15, 18]]) + + A multi-dimensional example: + + >>> A = np.array([[1, 2, 3], [4, 5, 6]]) + >>> A.shape + (2, 3) + >>> B = np.array([[1, 2, 3, 4]]) + >>> B.shape + (1, 4) + >>> C = np.multiply.outer(A, B) + >>> C.shape; C + (2, 3, 1, 4) + array([[[[ 1, 2, 3, 4]], + [[ 2, 4, 6, 8]], + [[ 3, 6, 9, 12]]], + [[[ 4, 8, 12, 16]], + [[ 5, 10, 15, 20]], + [[ 6, 12, 18, 24]]]]) + + """ + + dpnp.check_supported_arrays_type( + x1, x2, scalar_type=True, all_scalars=False + ) + if dpnp.isscalar(x1) or dpnp.isscalar(x2): + _x1 = x1 + _x2 = x2 + else: + _x1 = x1[(Ellipsis,) + (None,) * x2.ndim] + _x2 = x2[(None,) * x1.ndim + (Ellipsis,)] + return self.__call__( + _x1, + _x2, + out=out, + where=where, + order=order, + dtype=dtype, + subok=subok, + **kwargs, + ) + + +class unary_ufunc(ufunc): + def __init__( + self, + name, + mkl_call=False, + acceptance_fn=False, + ): + def _to_usm_astype(x, dtype, casting): + if dtype is not None: + x = dpnp.astype(x, dtype=dtype, casting=casting, copy=False) + x_usm = dpnp.get_usm_ndarray(x) + return (x_usm,) + + _name = "_" + name + + dpt_result_type = getattr(ti, _name + "_result_type") + dpt_impl_fn = getattr(ti, _name) + + def _call_func(src, dst, sycl_queue, depends=None): + """ + A callback to register in UnaryElementwiseFunc class of + dpctl.tensor + """ + + if depends is None: + depends = [] + + if mkl_call is True: + mkl_fn_to_call = getattr(vmi, "_mkl" + _name + "_to_call") + mkl_impl_fn = getattr(vmi, _name) + if mkl_fn_to_call is not None and mkl_fn_to_call( + sycl_queue, src, dst + ): + # call pybind11 extension for unary function from OneMKL VM + return mkl_impl_fn(sycl_queue, src, dst, depends) + return dpt_impl_fn(src, dst, sycl_queue, depends) + + func = UnaryElementwiseFunc( + name, + dpt_result_type, + _call_func, + self.__doc__, + acceptance_fn=acceptance_fn, + ) + + super().__init__(name, nin=1, func=func, to_usm_astype=_to_usm_astype) + + +class binary_ufunc(ufunc): + def __init__( + self, + name, + mkl_call=False, + inplace=False, + acceptance_fn=False, + ): + def _to_usm_astype(x1, x2, dtype, casting): + if dtype is not None: + if dpnp.isscalar(x1): + x1 = dpnp.asarray(x1, dtype=dtype) + x2 = dpnp.astype( + x2, dtype=dtype, casting=casting, copy=False + ) + elif dpnp.isscalar(x2): + x1 = dpnp.astype( + x1, dtype=dtype, casting=casting, copy=False + ) + x2 = dpnp.asarray(x2, dtype=dtype) + else: + x1 = dpnp.astype( + x1, dtype=dtype, casting=casting, copy=False + ) + x2 = dpnp.astype( + x2, dtype=dtype, casting=casting, copy=False + ) + x1_usm = dpnp.get_usm_ndarray_or_scalar(x1) + x2_usm = dpnp.get_usm_ndarray_or_scalar(x2) + return x1_usm, x2_usm + + _name = "_" + name + + dpt_result_type = getattr(ti, _name + "_result_type") + dpt_impl_fn = getattr(ti, _name) + + if inplace is True: + binary_inplace_fn = getattr(ti, _name + "_inplace") + else: + binary_inplace_fn = None + + def _call_func(src1, src2, dst, sycl_queue, depends=None): + """ + A callback to register in UnaryElementwiseFunc class of + dpctl.tensor + """ + + if depends is None: + depends = [] + + if mkl_call is True: + mkl_fn_to_call = getattr(vmi, "_mkl" + _name + "_to_call") + mkl_impl_fn = getattr(vmi, _name) + if mkl_fn_to_call is not None and mkl_fn_to_call( + sycl_queue, src1, src2, dst + ): + # call pybind11 extension for binary function from OneMKL VM + return mkl_impl_fn(sycl_queue, src1, src2, dst, depends) + return dpt_impl_fn(src1, src2, dst, sycl_queue, depends) + + func = BinaryElementwiseFunc( + name, + dpt_result_type, + _call_func, + self.__doc__, + binary_inplace_fn, + acceptance_fn=acceptance_fn, + ) + + super().__init__(name, nin=2, func=func, to_usm_astype=_to_usm_astype) + + class DPNPUnaryFunc(UnaryElementwiseFunc): """ Class that implements unary element-wise functions. diff --git a/dpnp/dpnp_algo/dpnp_elementwise_docs.py b/dpnp/dpnp_algo/dpnp_elementwise_docs.py new file mode 100644 index 000000000000..6d8195bb283f --- /dev/null +++ b/dpnp/dpnp_algo/dpnp_elementwise_docs.py @@ -0,0 +1,166 @@ +_unary_doc_template = """ +dpnp.%s(x, out=None, order='K', dtype=None, casting="same_kind", **kwargs) + +%s + +For full documentation refer to :obj:`numpy.%s`. + +Parameters +---------- +x : {dpnp.ndarray, usm_ndarray} + Input arrays, expected to have %s data type. +out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. + Array must have the correct shape and the expected data type. +order : {None, "C", "F", "A", "K"}, optional + Memory layout of the newly output array, Cannot be provided + together with `out`. Default: ``"K"``. +dtype : {None, dtype}, optional + If provided, the destination array will have this dtype. Cannot be + provided together with `out`. Default: ``None``. +casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional + Controls what kind of data casting may occur. Cannot be provided + together with `out`. Default: ``"safe"``. + +Returns +------- +out : dpnp.ndarray +%s + +Limitations +----------- +Keyword arguments `where` and `subok` are supported with their default values. +Other keyword arguments is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + +%s +""" + +_binary_doc_template = """ +dpnp.%s(x1, x2, out=None, order='K', dtype=None, casting="same_kind", **kwargs) + +%s + +For full documentation refer to :obj:`numpy.%s`. + +Parameters +---------- +x1, x2 : {dpnp.ndarray, usm_ndarray} + Input arrays, expected to have %s data type. +out : {None, dpnp.ndarray, usm_ndarray}, optional + Output array to populate. + Array must have the correct shape and the expected data type. +order : {None, "C", "F", "A", "K"}, optional + Memory layout of the newly output array, Cannot be provided + together with `out`. Default: ``"K"``. +dtype : {None, dtype}, optional + If provided, the destination array will have this dtype. Cannot be + provided together with `out`. Default: ``None``. +casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional + Controls what kind of data casting may occur. Cannot be provided + together with `out`. Default: ``"safe"``. + +Returns +------- +out : dpnp.ndarray +%s + +Limitations +----------- +Keyword arguments `where` and `subok` are supported with their default values. +Other keyword arguments is currently unsupported. +Otherwise ``NotImplementedError`` exception will be raised. + +%s +""" + + +name = "absolute" +dtypes = "numeric" +summary = """ +Calculates the absolute value for each element `x_i` of input array `x`. +""" +returns = """ + An array containing the element-wise absolute values. + For complex input, the absolute value is its magnitude. + If `x` has a real-valued data type, the returned array has the + same data type as `x`. If `x` has a complex floating-point data type, + the returned array has a real-valued floating-point data type whose + precision matches the precision of `x`. +""" +other = """ +See Also +-------- +:obj:`dpnp.fabs` : Calculate the absolute value element-wise excluding complex types. + +Notes +----- +``dpnp.abs`` is a shorthand for this function. + +Examples +-------- +>>> import dpnp as np +>>> a = np.array([-1.2, 1.2]) +>>> np.absolute(a) +array([1.2, 1.2]) + +>>> a = np.array(1.2 + 1j) +>>> np.absolute(a) +array(1.5620499351813308) +""" +abs_docstring = _unary_doc_template % ( + name, + summary, + name, + dtypes, + returns, + other, +) + + +name = "add" +dtypes = "numeric" +summary = """ +Calculates the sum for each element `x1_i` of the input array `x1` with the +respective element `x2_i` of the input array `x2`. +""" +returns = """ + An array containing the element-wise sums. The data type of the returned + array is determined by the Type Promotion Rules. +""" +other = """ +Notes +----- +Equivalent to `x1` + `x2` in terms of array broadcasting. + +Examples +-------- +>>> import dpnp as np +>>> a = np.array([1, 2, 3]) +>>> b = np.array([1, 2, 3]) +>>> np.add(a, b) +array([2, 4, 6]) + +>>> x1 = np.arange(9.0).reshape((3, 3)) +>>> x2 = np.arange(3.0) +>>> np.add(x1, x2) +array([[ 0., 2., 4.], + [ 3., 5., 7.], + [ 6., 8., 10.]]) + +The ``+`` operator can be used as a shorthand for ``add`` on +:class:`dpnp.ndarray`. + +>>> x1 + x2 +array([[ 0., 2., 4.], + [ 3., 5., 7.], + [ 6., 8., 10.]]) +""" +add_docstring = _binary_doc_template % ( + name, + summary, + name, + dtypes, + returns, + other, +) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index b0d0c7b61237..f935e1a18f0c 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -77,12 +77,16 @@ acceptance_fn_positive, acceptance_fn_sign, acceptance_fn_subtract, + binary_ufunc, + unary_ufunc, ) from .dpnp_array import dpnp_array from .dpnp_utils import call_origin, get_usm_allocations from .dpnp_utils.dpnp_utils_linearalgebra import dpnp_cross from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call +ufunc = dpnp.dpnp_algo.dpnp_elementwise_common.ufunc + __all__ = [ "abs", "absolute", @@ -130,6 +134,7 @@ "trapz", "true_divide", "trunc", + "ufunc", ] @@ -330,141 +335,14 @@ def _gradient_num_diff_edges( a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] ) - -_ABS_DOCSTRING = """ -Calculates the absolute value for each element `x_i` of input array `x`. - -For full documentation refer to :obj:`numpy.absolute`. - -Parameters ----------- -x : {dpnp.ndarray, usm_ndarray} - Input array, expected to have numeric data type. -out : {None, dpnp.ndarray}, optional - Output array to populate. - Array must have the correct shape and the expected data type. -order : {"C", "F", "A", "K"}, optional - Memory layout of the newly output array, if parameter `out` is ``None``. - Default: "K". - -Returns -------- -out : dpnp.ndarray - An array containing the element-wise absolute values. - For complex input, the absolute value is its magnitude. - If `x` has a real-valued data type, the returned array has the - same data type as `x`. If `x` has a complex floating-point data type, - the returned array has a real-valued floating-point data type whose - precision matches the precision of `x`. - -Limitations ------------ -Parameters `where` and `subok` are supported with their default values. -Keyword argument `kwargs` is currently unsupported. -Otherwise ``NotImplementedError`` exception will be raised. - -See Also --------- -:obj:`dpnp.fabs` : Calculate the absolute value element-wise excluding complex types. - -Notes ------ -``dpnp.abs`` is a shorthand for this function. - -Examples --------- ->>> import dpnp as np ->>> a = np.array([-1.2, 1.2]) ->>> np.absolute(a) -array([1.2, 1.2]) - ->>> a = np.array(1.2 + 1j) ->>> np.absolute(a) -array(1.5620499351813308) -""" - -absolute = DPNPUnaryFunc( - "abs", - ti._abs_result_type, - ti._abs, - _ABS_DOCSTRING, - mkl_fn_to_call=vmi._mkl_abs_to_call, - mkl_impl_fn=vmi._abs, -) + +absolute = unary_ufunc("abs", mkl_call=True) abs = absolute -_ADD_DOCSTRING = """ -Calculates the sum for each element `x1_i` of the input array `x1` with -the respective element `x2_i` of the input array `x2`. - -For full documentation refer to :obj:`numpy.add`. - -Parameters ----------- -x1 : {dpnp.ndarray, usm_ndarray} - First input array, expected to have numeric data type. -x2 : {dpnp.ndarray, usm_ndarray} - Second input array, also expected to have numeric data type. -out : {None, dpnp.ndarray}, optional - Output array to populate. - Array must have the correct shape and the expected data type. -order : {"C", "F", "A", "K"}, optional - Memory layout of the newly output array, if parameter `out` is ``None``. - Default: "K". - -Returns -------- -out : dpnp.ndarray - An array containing the element-wise sums. The data type of the - returned array is determined by the Type Promotion Rules. - -Limitations ------------ -Parameters `where` and `subok` are supported with their default values. -Keyword argument `kwargs` is currently unsupported. -Otherwise ``NotImplementedError`` exception will be raised. - -Notes ------ -Equivalent to `x1` + `x2` in terms of array broadcasting. - -Examples --------- ->>> import dpnp as np ->>> a = np.array([1, 2, 3]) ->>> b = np.array([1, 2, 3]) ->>> np.add(a, b) -array([2, 4, 6]) - ->>> x1 = np.arange(9.0).reshape((3, 3)) ->>> x2 = np.arange(3.0) ->>> np.add(x1, x2) -array([[ 0., 2., 4.], - [ 3., 5., 7.], - [ 6., 8., 10.]]) - -The ``+`` operator can be used as a shorthand for ``add`` on -:class:`dpnp.ndarray`. - ->>> x1 + x2 -array([[ 0., 2., 4.], - [ 3., 5., 7.], - [ 6., 8., 10.]]) -""" - - -add = DPNPBinaryFunc( - "add", - ti._add_result_type, - ti._add, - _ADD_DOCSTRING, - mkl_fn_to_call=vmi._mkl_add_to_call, - mkl_impl_fn=vmi._add, - binary_inplace_fn=ti._add_inplace, -) +add = binary_ufunc("add", mkl_call=True, inplace=True) _ANGLE_DOCSTRING = """