diff --git a/dpnp/backend/extensions/ufunc/CMakeLists.txt b/dpnp/backend/extensions/ufunc/CMakeLists.txt index d363910f74df..bbc6881ffcd0 100644 --- a/dpnp/backend/extensions/ufunc/CMakeLists.txt +++ b/dpnp/backend/extensions/ufunc/CMakeLists.txt @@ -36,6 +36,7 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/interpolate.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp @@ -69,6 +70,7 @@ endif() set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON) target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common) target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR}) target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR}) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp index 8ff89a1b03b6..c6dd3e038eb1 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp @@ -36,6 +36,7 @@ #include "gcd.hpp" #include "heaviside.hpp" #include "i0.hpp" +#include "interpolate.hpp" #include "lcm.hpp" #include "ldexp.hpp" #include "logaddexp2.hpp" @@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m) init_gcd(m); init_heaviside(m); init_i0(m); + init_interpolate(m); init_lcm(m); init_ldexp(m); init_logaddexp2(m); diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp new file mode 100644 index 000000000000..784cef224548 --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp @@ -0,0 +1,300 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include +#include + +#include "dpctl4pybind11.hpp" +#include +#include + +// dpctl tensor headers +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/interpolate.hpp" + +#include "ext/validation_utils.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using ext::validation::array_names; +using ext::validation::array_ptr; +using ext::validation::common_checks; + +namespace dpnp::extensions::ufunc +{ + +namespace impl +{ + +template +struct value_type_of +{ + using type = T; +}; + +template +struct value_type_of> +{ + using type = T; +}; + +template +using value_type_of_t = typename value_type_of::type; + +typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &, + const void *, // x + const void *, // idx + const void *, // xp + const void *, // fp + const void *, // left + const void *, // right + void *, // out + std::size_t, // n + std::size_t, // xp_size + const std::vector &); + +template +sycl::event interpolate_call(sycl::queue &exec_q, + const void *vx, + const void *vidx, + const void *vxp, + const void *vfp, + const void *vleft, + const void *vright, + void *vout, + std::size_t n, + std::size_t xp_size, + const std::vector &depends) +{ + using dpctl::tensor::type_utils::is_complex_v; + using TCoord = std::conditional_t, value_type_of_t, T>; + + const TCoord *x = static_cast(vx); + const std::int64_t *idx = static_cast(vidx); + const TCoord *xp = static_cast(vxp); + const T *fp = static_cast(vfp); + const T *left = static_cast(vleft); + const T *right = static_cast(vright); + T *out = static_cast(vout); + + using dpnp::kernels::interpolate::interpolate_impl; + sycl::event interpolate_ev = interpolate_impl( + exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends); + + return interpolate_ev; +} + +interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types]; + +void common_interpolate_checks( + const dpctl::tensor::usm_ndarray &x, + const dpctl::tensor::usm_ndarray &idx, + const dpctl::tensor::usm_ndarray &xp, + const dpctl::tensor::usm_ndarray &fp, + const dpctl::tensor::usm_ndarray &out, + const std::optional &left, + const std::optional &right) +{ + array_names names = {{&x, "x"}, {&xp, "xp"}, {&fp, "fp"}, {&out, "out"}}; + + auto array_types = td_ns::usm_ndarray_types(); + int x_type_id = array_types.typenum_to_lookup_id(x.get_typenum()); + int xp_type_id = array_types.typenum_to_lookup_id(xp.get_typenum()); + int fp_type_id = array_types.typenum_to_lookup_id(fp.get_typenum()); + int out_type_id = array_types.typenum_to_lookup_id(out.get_typenum()); + + if (x_type_id != xp_type_id) { + throw py::value_error("x and xp must have the same dtype"); + } + if (fp_type_id != out_type_id) { + throw py::value_error("fp and out must have the same dtype"); + } + + if (left) { + const auto &l = left.value(); + names.insert({&l, "left"}); + if (l.get_ndim() != 0) { + throw py::value_error("left must be a zero-dimensional array"); + } + + int left_type_id = array_types.typenum_to_lookup_id(l.get_typenum()); + if (left_type_id != fp_type_id) { + throw py::value_error( + "left must have the same dtype as fp and out"); + } + } + + if (right) { + const auto &r = right.value(); + names.insert({&r, "right"}); + if (r.get_ndim() != 0) { + throw py::value_error("right must be a zero-dimensional array"); + } + + int right_type_id = array_types.typenum_to_lookup_id(r.get_typenum()); + if (right_type_id != fp_type_id) { + throw py::value_error( + "right must have the same dtype as fp and out"); + } + } + + common_checks({&x, &xp, &fp, left ? &left.value() : nullptr, + right ? &right.value() : nullptr}, + {&out}, names); + + if (x.get_ndim() != 1 || xp.get_ndim() != 1 || fp.get_ndim() != 1 || + idx.get_ndim() != 1 || out.get_ndim() != 1) + { + throw py::value_error("All arrays must be one-dimensional"); + } + + if (xp.get_size() != fp.get_size()) { + throw py::value_error("xp and fp must have the same size"); + } + + if (x.get_size() != out.get_size() || x.get_size() != idx.get_size()) { + throw py::value_error("x, idx, and out must have the same size"); + } +} + +std::pair + py_interpolate(const dpctl::tensor::usm_ndarray &x, + const dpctl::tensor::usm_ndarray &idx, + const dpctl::tensor::usm_ndarray &xp, + const dpctl::tensor::usm_ndarray &fp, + std::optional &left, + std::optional &right, + dpctl::tensor::usm_ndarray &out, + sycl::queue &exec_q, + const std::vector &depends) +{ + if (x.get_size() == 0) { + return {sycl::event(), sycl::event()}; + } + + common_interpolate_checks(x, idx, xp, fp, out, left, right); + + int out_typenum = out.get_typenum(); + + auto array_types = td_ns::usm_ndarray_types(); + int out_type_id = array_types.typenum_to_lookup_id(out_typenum); + + auto fn = interpolate_dispatch_vector[out_type_id]; + if (!fn) { + throw py::type_error("Unsupported dtype"); + } + + std::size_t n = x.get_size(); + std::size_t xp_size = xp.get_size(); + + void *left_ptr = left ? left.value().get_data() : nullptr; + void *right_ptr = right ? right.value().get_data() : nullptr; + + sycl::event ev = + fn(exec_q, x.get_data(), idx.get_data(), xp.get_data(), fp.get_data(), + left_ptr, right_ptr, out.get_data(), n, xp_size, depends); + + sycl::event args_ev; + + if (left && right) { + args_ev = dpctl::utils::keep_args_alive( + exec_q, {x, idx, xp, fp, out, left.value(), right.value()}, {ev}); + } + else if (left) { + args_ev = dpctl::utils::keep_args_alive( + exec_q, {x, idx, xp, fp, out, left.value()}, {ev}); + } + else if (right) { + args_ev = dpctl::utils::keep_args_alive( + exec_q, {x, idx, xp, fp, out, right.value()}, {ev}); + } + else { + args_ev = + dpctl::utils::keep_args_alive(exec_q, {x, idx, xp, fp, out}, {ev}); + } + + return std::make_pair(args_ev, ev); +} + +/** + * @brief A factory to define pairs of supported types for which + * interpolate function is available. + * + * @tparam T Type of input vector `a` and of result vector `y`. + */ +template +struct InterpolateOutputType +{ + using value_type = typename std::disjunction< + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::DefaultResultEntry>::result_type; +}; + +template +struct InterpolateFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename InterpolateOutputType::value_type, void>) + { + return nullptr; + } + else { + return interpolate_call; + } + } +}; + +void init_interpolate_dispatch_vectors() +{ + using namespace td_ns; + + DispatchVectorBuilder + dtb_interpolate; + dtb_interpolate.populate_dispatch_vector(interpolate_dispatch_vector); +} + +} // namespace impl + +void init_interpolate(py::module_ m) +{ + impl::init_interpolate_dispatch_vectors(); + + using impl::py_interpolate; + m.def("_interpolate", &py_interpolate, "", py::arg("x"), py::arg("idx"), + py::arg("xp"), py::arg("fp"), py::arg("left"), py::arg("right"), + py::arg("out"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); +} + +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.hpp b/dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.hpp new file mode 100644 index 000000000000..4ae1cb2c8958 --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.hpp @@ -0,0 +1,35 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpnp::extensions::ufunc +{ +void init_interpolate(py::module_ m); +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/kernels/elementwise_functions/interpolate.hpp b/dpnp/backend/kernels/elementwise_functions/interpolate.hpp new file mode 100644 index 000000000000..76ad0946aa22 --- /dev/null +++ b/dpnp/backend/kernels/elementwise_functions/interpolate.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include "ext/common.hpp" +#include "utils/type_utils.hpp" + +namespace type_utils = dpctl::tensor::type_utils; + +using ext::common::IsNan; + +namespace dpnp::kernels::interpolate +{ +template +sycl::event interpolate_impl(sycl::queue &q, + const TCoord *x, + const std::int64_t *idx, + const TCoord *xp, + const TValue *fp, + const TValue *left, + const TValue *right, + TValue *out, + const std::size_t n, + const std::size_t xp_size, + const std::vector &depends) +{ + return q.submit([&](sycl::handler &h) { + h.depends_on(depends); + h.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) { + TValue left_val = left ? *left : fp[0]; + TValue right_val = right ? *right : fp[xp_size - 1]; + + TCoord x_val = x[i]; + std::int64_t x_idx = idx[i] - 1; + + if (IsNan::isnan(x_val)) { + out[i] = x_val; + } + else if (x_idx < 0) { + out[i] = left_val; + } + else if (x_val == xp[xp_size - 1]) { + out[i] = fp[xp_size - 1]; + } + else if (x_idx >= static_cast(xp_size - 1)) { + out[i] = right_val; + } + else { + TValue slope = + (fp[x_idx + 1] - fp[x_idx]) / (xp[x_idx + 1] - xp[x_idx]); + TValue res = slope * (x_val - xp[x_idx]) + fp[x_idx]; + + if (IsNan::isnan(res)) { + res = slope * (x_val - xp[x_idx + 1]) + fp[x_idx + 1]; + if (IsNan::isnan(res) && + (fp[x_idx] == fp[x_idx + 1])) { + res = fp[x_idx]; + } + } + out[i] = res; + } + }); + }); +} + +} // namespace dpnp::kernels::interpolate diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 3613c9bffff6..5c5eb7b05244 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -111,8 +111,9 @@ "gcd", "gradient", "heaviside", - "imag", "i0", + "imag", + "interp", "lcm", "ldexp", "maximum", @@ -348,6 +349,40 @@ def _process_ediff1d_args(arg, arg_name, ary_dtype, ary_sycl_queue, usm_type): return arg, usm_type +def _validate_interp_param(param, name, exec_q, usm_type, dtype=None): + """ + Validate and convert optional parameters for interpolation. + + Returns a USM array or None if the input is None. + """ + if param is None: + return None + + if dpnp.is_supported_array_type(param): + if param.ndim != 0: + raise ValueError( + f"a {name} value must be 0-dimensional, " + f"but got {param.ndim}-dim" + ) + if dpu.get_execution_queue([exec_q, param.sycl_queue]) is None: + raise ValueError( + "input arrays and {name} must be on the same SYCL queue" + ) + if dtype is not None: + param = param.astype(dtype) + return param.get_array() + + if dpnp.isscalar(param): + return dpt.asarray( + param, dtype=dtype, sycl_queue=exec_q, usm_type=usm_type + ) + + raise TypeError( + f"a {name} value must be a scalar or 0-d supported array, " + f"but got {type(param)}" + ) + + _ABS_DOCSTRING = """ Calculates the absolute value for each element :math:`x_i` of input array `x`. @@ -2742,6 +2777,180 @@ def gradient(f, *varargs, axis=None, edge_order=1): ) +def interp(x, xp, fp, left=None, right=None, period=None): + """ + One-dimensional linear interpolation. + + Returns the one-dimensional piecewise linear interpolant to a function + with given discrete data points (`xp`, `fp`), evaluated at `x`. + + For full documentation refer to :obj:`numpy.interp`. + + Parameters + ---------- + x : {dpnp.ndarray, usm_ndarray} + Input 1-D array. The x-coordinates at which to evaluate + the interpolated values. + + xp : {dpnp.ndarray, usm_ndarray} + Input 1-D array. The x-coordinates of the data points, + must be increasing if argument `period` is not specified. + Otherwise, `xp` is internally sorted after normalizing + the periodic boundaries with ``xp = xp % period``. + + fp : {dpnp.ndarray, usm_ndarray} + Input 1-D array. The y-coordinates of the data points, + same length as `xp`. + + left : {None, scalar, dpnp.ndarray, usm_ndarray}, optional + Value to return for `x < xp[0]`. + + Default: ``fp[0]``. + + right : {None, scalar, dpnp.ndarray, usm_ndarray}, optional + Value to return for `x > xp[-1]`. + + Default: ``fp[-1]``. + + period : {None, scalar, dpnp.ndarray, usm_ndarray}, optional + A period for the x-coordinates. This parameter allows the proper + interpolation of angular x-coordinates. Parameters `left` and `right` + are ignored if `period` is specified. + + Default: ``None``. + + Returns + ------- + y : {dpnp.ndarray, usm_ndarray} + The interpolated values, same shape as `x`. + + + Warnings + -------- + The x-coordinate sequence is expected to be increasing, but this is not + explicitly enforced. However, if the sequence `xp` is non-increasing, + interpolation results are meaningless. + + Note that, since NaN is unsortable, `xp` also cannot contain NaNs. + + A simple check for `xp` being strictly increasing is:: + + import dpnp as np + np.all(np.diff(xp) > 0) + + Examples + -------- + >>> import dpnp as np + >>> xp = np.array([1, 2, 3]) + >>> fp = np.array([3 ,2 ,0]) + >>> x = np.array([2.5]) + >>> np.interp(x, xp, fp) + array([1.]) + >>> x = np.array([0, 1, 1.5, 2.72, 3.14]) + >>> np.interp(x, xp, fp) + array([3. , 3. , 2.5 , 0.56, 0. ]) + >>> x = np.array([3.14]) + >>> UNDEF = -99.0 + >>> np.interp(x, xp, fp, right=UNDEF) + array([-99.]) + + Interpolation with periodic x-coordinates: + + >>> x = np.array([-180, -170, -185, 185, -10, -5, 0, 365]) + >>> xp = np.array([190, -190, 350, -350]) + >>> fp = np.array([5, 10, 3, 4]) + >>> np.interp(x, xp, fp, period=360) + array([7.5 , 5. , 8.75, 6.25, 3. , 3.25, 3.5 , 3.75]) + + Complex interpolation: + + >>> x = np.array([1.5, 4.0]) + >>> xp = np.array([2,3,5]) + >>> fp = np.array([1.0j, 0, 2+3j]) + >>> np.interp(x, xp, fp) + array([0.+1.j , 1.+1.5j]) + + """ + + dpnp.check_supported_arrays_type(x, xp, fp) + + if xp.ndim != 1 or fp.ndim != 1: + raise ValueError("xp and fp must be 1D arrays") + if xp.size != fp.size: + raise ValueError("fp and xp are not of the same length") + if xp.size == 0: + raise ValueError("array of sample points is empty") + + usm_type, exec_q = get_usm_allocations([x, xp, fp]) + + x_dtype = dpnp.common_type(x, xp) + x_float_type = dpnp.default_float_type(exec_q) + + if not dpnp.can_cast(x_dtype, x_float_type): + raise TypeError( + "Cannot cast array data from" + f" {x_dtype} to {x_float_type} " + "according to the rule 'safe'" + ) + + x = dpnp.asarray(x, dtype=x_float_type, order="C") + xp = dpnp.asarray(xp, dtype=x_float_type, order="C") + + out_dtype = dpnp.common_type(x, xp, fp) + + fp = dpnp.asarray(fp, dtype=out_dtype, order="C") + + if period is not None: + period = _validate_interp_param(period, "period", exec_q, usm_type) + if period == 0: + raise ValueError("period must be a non-zero value") + period = dpnp.abs(period) + + # left/right are ignored when period is specified + left = None + right = None + + # normalizing periodic boundaries + x %= period + xp %= period + asort_xp = dpnp.argsort(xp) + xp = xp[asort_xp] + fp = fp[asort_xp] + xp = dpnp.concatenate((xp[-1:] - period, xp, xp[0:1] + period)) + fp = dpnp.concatenate((fp[-1:], fp, fp[0:1])) + assert xp.flags.c_contiguous + assert fp.flags.c_contiguous + + idx = dpnp.searchsorted(xp, x, side="right") + left_usm = _validate_interp_param(left, "left", exec_q, usm_type, fp.dtype) + right_usm = _validate_interp_param( + right, "right", exec_q, usm_type, fp.dtype + ) + + usm_type, exec_q = get_usm_allocations( + [x, xp, fp, period, left_usm, right_usm] + ) + output = dpnp.empty( + x.shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type + ) + + _manager = dpu.SequentialOrderManager[exec_q] + mem_ev, ht_ev = ufi._interpolate( + x.get_array(), + idx.get_array(), + xp.get_array(), + fp.get_array(), + left_usm, + right_usm, + output.get_array(), + exec_q, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(mem_ev, ht_ev) + + return output + + _LCM_DOCSTRING = r""" Returns the lowest common multiple of :math:`\abs{x1}` and :math:`\abs{x2}`. diff --git a/dpnp/tests/test_mathematical.py b/dpnp/tests/test_mathematical.py index 7e917b665ee0..f125187d6bf4 100644 --- a/dpnp/tests/test_mathematical.py +++ b/dpnp/tests/test_mathematical.py @@ -1143,6 +1143,160 @@ def test_complex(self, xp): assert_raises((ValueError, TypeError), xp.i0, a) +class TestInterp: + @pytest.mark.parametrize( + "dtype_x", get_all_dtypes(no_bool=True, no_complex=True) + ) + @pytest.mark.parametrize("dtype_y", get_all_dtypes(no_bool=True)) + def test_all_dtypes(self, dtype_x, dtype_y): + x = numpy.linspace(0.1, 9.9, 20).astype(dtype_x) + xp = numpy.linspace(0.0, 10.0, 5).astype(dtype_x) + fp = (xp * 1.5 + 1).astype(dtype_y) + + ix = dpnp.array(x) + ixp = dpnp.array(xp) + ifp = dpnp.array(fp) + + expected = numpy.interp(x, xp, fp) + result = dpnp.interp(ix, ixp, ifp) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "dtype_x", get_all_dtypes(no_bool=True, no_complex=True) + ) + @pytest.mark.parametrize("dtype_y", get_complex_dtypes()) + def test_complex_fp(self, dtype_x, dtype_y): + x = numpy.array([0.25, 0.75], dtype=dtype_x) + xp = numpy.array([0.0, 1.0], dtype=dtype_x) + fp = numpy.array([1 + 1j, 3 + 3j], dtype=dtype_y) + + ix = dpnp.array(x) + ixp = dpnp.array(xp) + ifp = dpnp.array(fp) + + expected = numpy.interp(x, xp, fp) + result = dpnp.interp(ix, ixp, ifp) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_complex=True) + ) + def test_left_right_args(self, dtype): + x = numpy.array([-1, 0, 1, 2, 3, 4, 5, 6], dtype=dtype) + xp = numpy.array([0, 3, 6], dtype=dtype) + fp = numpy.array([0, 9, 18], dtype=dtype) + + ix = dpnp.array(x) + ixp = dpnp.array(xp) + ifp = dpnp.array(fp) + + expected = numpy.interp(x, xp, fp, left=-40, right=40) + result = dpnp.interp(ix, ixp, ifp, left=-40, right=40) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("val", [numpy.nan, numpy.inf, -numpy.inf]) + def test_naninf(self, val): + x = numpy.array([0, 1, 2, val]) + xp = numpy.array([0, 1, 2]) + fp = numpy.array([10, 20, 30]) + + ix = dpnp.array(x) + ixp = dpnp.array(xp) + ifp = dpnp.array(fp) + + expected = numpy.interp(x, xp, fp) + result = dpnp.interp(ix, ixp, ifp) + assert_dtype_allclose(result, expected) + + def test_empty_x(self): + x = numpy.array([]) + xp = numpy.array([0, 1]) + fp = numpy.array([10, 20]) + + ix = dpnp.array(x) + ixp = dpnp.array(xp) + ifp = dpnp.array(fp) + + expected = numpy.interp(x, xp, fp) + result = dpnp.interp(ix, ixp, ifp) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + def test_period(self, dtype): + x = numpy.array([-180, 0, 180], dtype=dtype) + xp = numpy.array([-90, 0, 90], dtype=dtype) + fp = numpy.array([0, 1, 0], dtype=dtype) + + ix = dpnp.array(x) + ixp = dpnp.array(xp) + ifp = dpnp.array(fp) + + expected = numpy.interp(x, xp, fp, period=180) + result = dpnp.interp(ix, ixp, ifp, period=180) + assert_dtype_allclose(result, expected) + + def test_errors(self): + x = dpnp.array([0.5]) + + # xp and fp have different lengths + xp = dpnp.array([0]) + fp = dpnp.array([1, 2]) + assert_raises(ValueError, dpnp.interp, x, xp, fp) + + # xp is not 1D + xp = dpnp.array([[0, 1]]) + fp = dpnp.array([1, 2]) + assert_raises(ValueError, dpnp.interp, x, xp, fp) + + # fp is not 1D + xp = dpnp.array([0, 1]) + fp = dpnp.array([[1, 2]]) + assert_raises(ValueError, dpnp.interp, x, xp, fp) + + # xp and fp are empty + xp = dpnp.array([]) + fp = dpnp.array([]) + assert_raises(ValueError, dpnp.interp, x, xp, fp) + + # x complex + x_complex = dpnp.array([1 + 2j]) + xp = dpnp.array([0.0, 2.0]) + fp = dpnp.array([0.0, 1.0]) + assert_raises(TypeError, dpnp.interp, x_complex, xp, fp) + + # period is zero + x = dpnp.array([1.0]) + xp = dpnp.array([0.0, 2.0]) + fp = dpnp.array([0.0, 1.0]) + assert_raises(ValueError, dpnp.interp, x, xp, fp, period=0) + + # period is not scalar or 0-dim + assert_raises(TypeError, dpnp.interp, x, xp, fp, period=[180]) + + # period has a different SYCL queue + q1 = dpctl.SyclQueue() + q2 = dpctl.SyclQueue() + + x = dpnp.array([1.0], sycl_queue=q1) + xp = dpnp.array([0.0, 2.0], sycl_queue=q1) + fp = dpnp.array([0.0, 1.0], sycl_queue=q1) + period = dpnp.array([180], sycl_queue=q2) + assert_raises(ValueError, dpnp.interp, x, xp, fp, period=period) + + # left is not scalar or 0-dim + left = dpnp.array([1.0]) + assert_raises(ValueError, dpnp.interp, x, xp, fp, left=left) + + # left is 1-d array + left = dpnp.array([1.0]) + assert_raises(ValueError, dpnp.interp, x, xp, fp, left=left) + + # left has a different SYCL queue + left = dpnp.array(1.0, sycl_queue=q2) + if q1 != q2: + assert_raises(ValueError, dpnp.interp, x, xp, fp, left=left) + + @pytest.mark.parametrize( "rhs", [[[1, 2, 3], [4, 5, 6]], [2.0, 1.5, 1.0], 3, 0.3] ) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index b0112702e308..1f015d6ab2dd 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -1453,6 +1453,24 @@ def test_choose(device): assert_sycl_queue_equal(result.sycl_queue, chc.sycl_queue) +@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) +@pytest.mark.parametrize("left", [None, dpnp.array(-1.0)]) +@pytest.mark.parametrize("right", [None, dpnp.array(99.0)]) +@pytest.mark.parametrize("period", [None, dpnp.array(180.0)]) +def test_interp(device, left, right, period): + x = dpnp.linspace(0.1, 9.9, 20, device=device) + xp = dpnp.linspace(0.0, 10.0, 5, sycl_queue=x.sycl_queue) + fp = dpnp.array(xp * 2 + 1, sycl_queue=x.sycl_queue) + + l = None if left is None else dpnp.array(left, sycl_queue=x.sycl_queue) + r = None if right is None else dpnp.array(right, sycl_queue=x.sycl_queue) + p = None if period is None else dpnp.array(period, sycl_queue=x.sycl_queue) + + result = dpnp.interp(x, xp, fp, left=l, right=r, period=p) + + assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue) + + @pytest.mark.parametrize("device", valid_dev, ids=dev_ids) class TestLinAlgebra: @pytest.mark.parametrize( diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index 1d512ce111a6..ad8b6ba9403f 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1268,6 +1268,65 @@ def test_choose(usm_type_x, usm_type_ind): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind]) +class TestInterp: + @pytest.mark.parametrize("usm_type_x", list_of_usm_types) + @pytest.mark.parametrize("usm_type_xp", list_of_usm_types) + @pytest.mark.parametrize("usm_type_fp", list_of_usm_types) + def test_basic(self, usm_type_x, usm_type_xp, usm_type_fp): + x = dpnp.linspace(0.1, 9.9, 20, usm_type=usm_type_x) + xp = dpnp.linspace(0.0, 10.0, 5, usm_type=usm_type_xp) + fp = dpnp.array(xp * 2 + 1, usm_type=usm_type_fp) + + result = dpnp.interp(x, xp, fp) + + assert x.usm_type == usm_type_x + assert xp.usm_type == usm_type_xp + assert fp.usm_type == usm_type_fp + assert result.usm_type == du.get_coerced_usm_type( + [usm_type_x, usm_type_xp, usm_type_fp] + ) + + @pytest.mark.parametrize("usm_type_x", list_of_usm_types) + @pytest.mark.parametrize("usm_type_left", list_of_usm_types) + @pytest.mark.parametrize("usm_type_right", list_of_usm_types) + def test_left_right(self, usm_type_x, usm_type_left, usm_type_right): + x = dpnp.linspace(-1.0, 11.0, 5, usm_type=usm_type_x) + xp = dpnp.linspace(0.0, 10.0, 5, usm_type=usm_type_x) + fp = dpnp.array(xp * 2 + 1, usm_type=usm_type_x) + + left = dpnp.array(-100, usm_type=usm_type_left) + right = dpnp.array(100, usm_type=usm_type_right) + + result = dpnp.interp(x, xp, fp, left=left, right=right) + + assert left.usm_type == usm_type_left + assert right.usm_type == usm_type_right + assert result.usm_type == du.get_coerced_usm_type( + [ + x.usm_type, + xp.usm_type, + fp.usm_type, + left.usm_type, + right.usm_type, + ] + ) + + @pytest.mark.parametrize("usm_type_x", list_of_usm_types) + @pytest.mark.parametrize("usm_type_period", list_of_usm_types) + def test_period(self, usm_type_x, usm_type_period): + x = dpnp.linspace(0.1, 9.9, 20, usm_type=usm_type_x) + xp = dpnp.linspace(0.0, 10.0, 5, usm_type=usm_type_x) + fp = dpnp.array(xp * 2 + 1, usm_type=usm_type_x) + period = dpnp.array(10.0, usm_type=usm_type_period) + + result = dpnp.interp(x, xp, fp, period=period) + + assert period.usm_type == usm_type_period + assert result.usm_type == du.get_coerced_usm_type( + [x.usm_type, xp.usm_type, fp.usm_type, period.usm_type] + ) + + @pytest.mark.parametrize("usm_type", list_of_usm_types) class TestLinAlgebra: @pytest.mark.parametrize( diff --git a/dpnp/tests/third_party/cupy/math_tests/test_misc.py b/dpnp/tests/third_party/cupy/math_tests/test_misc.py index 7746c56d3253..4542c51de33e 100644 --- a/dpnp/tests/third_party/cupy/math_tests/test_misc.py +++ b/dpnp/tests/third_party/cupy/math_tests/test_misc.py @@ -367,10 +367,9 @@ def test_real_if_close_with_float_tol_false(self, xp, dtype): assert x.dtype == out.dtype return out - @pytest.mark.skip("interp() is not supported yet") @testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True) @testing.for_all_dtypes(name="dtype_y", no_bool=True) - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -378,10 +377,9 @@ def test_interp(self, xp, dtype_y, dtype_x): fy = xp.sin(fx).astype(dtype_y) return xp.interp(x, fx, fy) - @pytest.mark.skip("interp() is not supported yet") @testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True) @testing.for_all_dtypes(name="dtype_y", no_bool=True) - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_period(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -389,10 +387,9 @@ def test_interp_period(self, xp, dtype_y, dtype_x): fy = xp.sin(fx).astype(dtype_y) return xp.interp(x, fx, fy, period=5) - @pytest.mark.skip("interp() is not supported yet") @testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True) @testing.for_all_dtypes(name="dtype_y", no_bool=True) - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_left_right(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -402,11 +399,10 @@ def test_interp_left_right(self, xp, dtype_y, dtype_x): right = 20 return xp.interp(x, fx, fy, left, right) - @pytest.mark.skip("interp() is not supported yet") @testing.with_requires("numpy>=1.17.0") @testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True) @testing.for_dtypes("efdFD", name="dtype_y") - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_nan_fy(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -415,11 +411,10 @@ def test_interp_nan_fy(self, xp, dtype_y, dtype_x): fy[0] = fy[2] = fy[-1] = numpy.nan return xp.interp(x, fx, fy) - @pytest.mark.skip("interp() is not supported yet") @testing.with_requires("numpy>=1.17.0") @testing.for_float_dtypes(name="dtype_x") @testing.for_dtypes("efdFD", name="dtype_y") - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_nan_fx(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -428,11 +423,10 @@ def test_interp_nan_fx(self, xp, dtype_y, dtype_x): fx[-1] = numpy.nan # x and fx must remain sorted (NaNs are the last) return xp.interp(x, fx, fy) - @pytest.mark.skip("interp() is not supported yet") @testing.with_requires("numpy>=1.17.0") @testing.for_float_dtypes(name="dtype_x") @testing.for_dtypes("efdFD", name="dtype_y") - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_nan_x(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -441,11 +435,10 @@ def test_interp_nan_x(self, xp, dtype_y, dtype_x): x[-1] = numpy.nan # x and fx must remain sorted (NaNs are the last) return xp.interp(x, fx, fy) - @pytest.mark.skip("interp() is not supported yet") @testing.with_requires("numpy>=1.17.0") @testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True) @testing.for_dtypes("efdFD", name="dtype_y") - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_inf_fy(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -454,11 +447,10 @@ def test_interp_inf_fy(self, xp, dtype_y, dtype_x): fy[0] = fy[2] = fy[-1] = numpy.inf return xp.interp(x, fx, fy) - @pytest.mark.skip("interp() is not supported yet") @testing.with_requires("numpy>=1.17.0") @testing.for_float_dtypes(name="dtype_x") @testing.for_dtypes("efdFD", name="dtype_y") - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_inf_fx(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -467,11 +459,10 @@ def test_interp_inf_fx(self, xp, dtype_y, dtype_x): fx[-1] = numpy.inf # x and fx must remain sorted return xp.interp(x, fx, fy) - @pytest.mark.skip("interp() is not supported yet") @testing.with_requires("numpy>=1.17.0") @testing.for_float_dtypes(name="dtype_x") @testing.for_dtypes("efdFD", name="dtype_y") - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_inf_x(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -480,10 +471,9 @@ def test_interp_inf_x(self, xp, dtype_y, dtype_x): x[-1] = numpy.inf # x and fx must remain sorted return xp.interp(x, fx, fy) - @pytest.mark.skip("interp() is not supported yet") @testing.for_all_dtypes(name="dtype_x", no_bool=True, no_complex=True) @testing.for_all_dtypes(name="dtype_y", no_bool=True) - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_size1(self, xp, dtype_y, dtype_x): # interpolate at points on and outside the boundaries x = xp.asarray([0, 1, 2, 4, 6, 8, 9, 10], dtype=dtype_x) @@ -493,11 +483,10 @@ def test_interp_size1(self, xp, dtype_y, dtype_x): right = 20 return xp.interp(x, fx, fy, left, right) - @pytest.mark.skip("interp() is not supported yet") @testing.with_requires("numpy>=1.17.0") @testing.for_float_dtypes(name="dtype_x") @testing.for_dtypes("efdFD", name="dtype_y") - @testing.numpy_cupy_allclose(atol=1e-5) + @testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64()) def test_interp_inf_to_nan(self, xp, dtype_y, dtype_x): # from NumPy's test_non_finite_inf x = xp.asarray([0.5], dtype=dtype_x)