Skip to content

Implement dpnp.interp() #2417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 30 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a24d367
Initial impl of dpnp.inter()
vlad-perevezentsev Mar 25, 2025
e2b20b0
Second impl with dispatch_vector[only floating]
vlad-perevezentsev Apr 2, 2025
f7d1da9
Implement interpolate_complex
vlad-perevezentsev Apr 2, 2025
e1b8698
Move interpolate backend to ufunc
vlad-perevezentsev Apr 2, 2025
0037455
Move def interp()to dpnp_iface_mathematical
vlad-perevezentsev Apr 2, 2025
7866eb8
Use dispatch vector and remove interpolate_complex_impl
vlad-perevezentsev Apr 2, 2025
51b3bde
Add more backend checks
vlad-perevezentsev Apr 2, 2025
ecfa37d
Add support left/right args
vlad-perevezentsev Apr 10, 2025
5d53f9c
Use get_usm_allocations in def interp
vlad-perevezentsev Apr 10, 2025
9dbc2c5
Pass idx as std::int64_t
vlad-perevezentsev Apr 11, 2025
1bafd7c
Add proper casting input array
vlad-perevezentsev Apr 11, 2025
2f43fd7
Update def interp to support period args
vlad-perevezentsev Apr 11, 2025
ae65091
Return fp[-1] instead of right_val for x==xp[-1]
vlad-perevezentsev Apr 11, 2025
771d3eb
Unskip cupy tests for interp
vlad-perevezentsev Apr 11, 2025
5cda3d2
Add dpnp tests for interp
vlad-perevezentsev Apr 11, 2025
a65a1dd
Update docstrings for def interp()
vlad-perevezentsev Apr 11, 2025
3146234
Merge master into impl_of_interp
vlad-perevezentsev Apr 11, 2025
99cc8b5
Remove lines after merging
vlad-perevezentsev Apr 11, 2025
5ec0738
Merge master into impl_of_interp
vlad-perevezentsev Apr 11, 2025
1263eb5
Add type_check flag to cupy tests
vlad-perevezentsev Apr 14, 2025
7c1fdf1
Merge master into impl_of_interp
vlad-perevezentsev Apr 14, 2025
b84dd7e
Add common_interpolate_checks with common utils
vlad-perevezentsev Apr 14, 2025
e9e357c
Reuse IsNan from common utils
vlad-perevezentsev Apr 14, 2025
50e4513
Remove dublicate copy
vlad-perevezentsev Apr 14, 2025
dbeb313
Add _validate_interp_param() function
vlad-perevezentsev Apr 14, 2025
dbb1b55
Impove code coverage
vlad-perevezentsev Apr 15, 2025
cbe7e7a
Add sycl_queue tests for interp
vlad-perevezentsev Apr 15, 2025
aa102bd
Add usm_type tests for interp()
vlad-perevezentsev Apr 15, 2025
28b2a52
Merge master into impl_of_interp
vlad-perevezentsev Apr 15, 2025
82c657e
Fix pre-commit remark
vlad-perevezentsev Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpnp/backend/extensions/ufunc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/interpolate.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
Expand Down Expand Up @@ -69,6 +70,7 @@ endif()
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)

target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)

target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "gcd.hpp"
#include "heaviside.hpp"
#include "i0.hpp"
#include "interpolate.hpp"
#include "lcm.hpp"
#include "ldexp.hpp"
#include "logaddexp2.hpp"
Expand Down Expand Up @@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
init_gcd(m);
init_heaviside(m);
init_i0(m);
init_interpolate(m);
init_lcm(m);
init_ldexp(m);
init_logaddexp2(m);
Expand Down
300 changes: 300 additions & 0 deletions dpnp/backend/extensions/ufunc/elementwise_functions/interpolate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <complex>
#include <vector>

#include "dpctl4pybind11.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// dpctl tensor headers
#include "utils/output_validation.hpp"
#include "utils/type_dispatch.hpp"

#include "kernels/elementwise_functions/interpolate.hpp"

#include "ext/validation_utils.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using ext::validation::array_names;
using ext::validation::array_ptr;
using ext::validation::common_checks;

namespace dpnp::extensions::ufunc
{

namespace impl
{

template <typename T>
struct value_type_of
{
using type = T;
};

template <typename T>
struct value_type_of<std::complex<T>>
{
using type = T;
};

template <typename T>
using value_type_of_t = typename value_type_of<T>::type;

typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &,
const void *, // x
const void *, // idx
const void *, // xp
const void *, // fp
const void *, // left
const void *, // right
void *, // out
std::size_t, // n
std::size_t, // xp_size
const std::vector<sycl::event> &);

template <typename T>
sycl::event interpolate_call(sycl::queue &exec_q,
const void *vx,
const void *vidx,
const void *vxp,
const void *vfp,
const void *vleft,
const void *vright,
void *vout,
std::size_t n,
std::size_t xp_size,
const std::vector<sycl::event> &depends)
{
using dpctl::tensor::type_utils::is_complex_v;
using TCoord = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;

const TCoord *x = static_cast<const TCoord *>(vx);
const std::int64_t *idx = static_cast<const std::int64_t *>(vidx);
const TCoord *xp = static_cast<const TCoord *>(vxp);
const T *fp = static_cast<const T *>(vfp);
const T *left = static_cast<const T *>(vleft);
const T *right = static_cast<const T *>(vright);
T *out = static_cast<T *>(vout);

using dpnp::kernels::interpolate::interpolate_impl;
sycl::event interpolate_ev = interpolate_impl<TCoord, T>(
exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends);

return interpolate_ev;
}

interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];

void common_interpolate_checks(
const dpctl::tensor::usm_ndarray &x,
const dpctl::tensor::usm_ndarray &idx,
const dpctl::tensor::usm_ndarray &xp,
const dpctl::tensor::usm_ndarray &fp,
const dpctl::tensor::usm_ndarray &out,
const std::optional<const dpctl::tensor::usm_ndarray> &left,
const std::optional<const dpctl::tensor::usm_ndarray> &right)
{
array_names names = {{&x, "x"}, {&xp, "xp"}, {&fp, "fp"}, {&out, "out"}};

auto array_types = td_ns::usm_ndarray_types();
int x_type_id = array_types.typenum_to_lookup_id(x.get_typenum());
int xp_type_id = array_types.typenum_to_lookup_id(xp.get_typenum());
int fp_type_id = array_types.typenum_to_lookup_id(fp.get_typenum());
int out_type_id = array_types.typenum_to_lookup_id(out.get_typenum());

if (x_type_id != xp_type_id) {
throw py::value_error("x and xp must have the same dtype");
}
if (fp_type_id != out_type_id) {
throw py::value_error("fp and out must have the same dtype");
}

if (left) {
const auto &l = left.value();
names.insert({&l, "left"});
if (l.get_ndim() != 0) {
throw py::value_error("left must be a zero-dimensional array");
}

int left_type_id = array_types.typenum_to_lookup_id(l.get_typenum());
if (left_type_id != fp_type_id) {
throw py::value_error(
"left must have the same dtype as fp and out");
}
}

if (right) {
const auto &r = right.value();
names.insert({&r, "right"});
if (r.get_ndim() != 0) {
throw py::value_error("right must be a zero-dimensional array");
}

int right_type_id = array_types.typenum_to_lookup_id(r.get_typenum());
if (right_type_id != fp_type_id) {
throw py::value_error(
"right must have the same dtype as fp and out");
}
}

common_checks({&x, &xp, &fp, left ? &left.value() : nullptr,
right ? &right.value() : nullptr},
{&out}, names);

if (x.get_ndim() != 1 || xp.get_ndim() != 1 || fp.get_ndim() != 1 ||
idx.get_ndim() != 1 || out.get_ndim() != 1)
{
throw py::value_error("All arrays must be one-dimensional");
}

if (xp.get_size() != fp.get_size()) {
throw py::value_error("xp and fp must have the same size");
}

if (x.get_size() != out.get_size() || x.get_size() != idx.get_size()) {
throw py::value_error("x, idx, and out must have the same size");
}
}

std::pair<sycl::event, sycl::event>
py_interpolate(const dpctl::tensor::usm_ndarray &x,
const dpctl::tensor::usm_ndarray &idx,
const dpctl::tensor::usm_ndarray &xp,
const dpctl::tensor::usm_ndarray &fp,
std::optional<const dpctl::tensor::usm_ndarray> &left,
std::optional<const dpctl::tensor::usm_ndarray> &right,
dpctl::tensor::usm_ndarray &out,
sycl::queue &exec_q,
const std::vector<sycl::event> &depends)
{
if (x.get_size() == 0) {
return {sycl::event(), sycl::event()};
}

common_interpolate_checks(x, idx, xp, fp, out, left, right);

int out_typenum = out.get_typenum();

auto array_types = td_ns::usm_ndarray_types();
int out_type_id = array_types.typenum_to_lookup_id(out_typenum);

auto fn = interpolate_dispatch_vector[out_type_id];
if (!fn) {
throw py::type_error("Unsupported dtype");
}

std::size_t n = x.get_size();
std::size_t xp_size = xp.get_size();

void *left_ptr = left ? left.value().get_data() : nullptr;
void *right_ptr = right ? right.value().get_data() : nullptr;

sycl::event ev =
fn(exec_q, x.get_data(), idx.get_data(), xp.get_data(), fp.get_data(),
left_ptr, right_ptr, out.get_data(), n, xp_size, depends);

sycl::event args_ev;

if (left && right) {
args_ev = dpctl::utils::keep_args_alive(
exec_q, {x, idx, xp, fp, out, left.value(), right.value()}, {ev});
}
else if (left) {
args_ev = dpctl::utils::keep_args_alive(
exec_q, {x, idx, xp, fp, out, left.value()}, {ev});
}
else if (right) {
args_ev = dpctl::utils::keep_args_alive(
exec_q, {x, idx, xp, fp, out, right.value()}, {ev});
}
else {
args_ev =
dpctl::utils::keep_args_alive(exec_q, {x, idx, xp, fp, out}, {ev});
}

return std::make_pair(args_ev, ev);
}

/**
* @brief A factory to define pairs of supported types for which
* interpolate function is available.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct InterpolateOutputType
{
using value_type = typename std::disjunction<
td_ns::TypeMapResultEntry<T, sycl::half>,
td_ns::TypeMapResultEntry<T, float>,
td_ns::TypeMapResultEntry<T, double>,
td_ns::TypeMapResultEntry<T, std::complex<float>>,
td_ns::TypeMapResultEntry<T, std::complex<double>>,
td_ns::DefaultResultEntry<void>>::result_type;
};

template <typename fnT, typename T>
struct InterpolateFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename InterpolateOutputType<T>::value_type, void>)
{
return nullptr;
}
else {
return interpolate_call<T>;
}
}
};

void init_interpolate_dispatch_vectors()
{
using namespace td_ns;

DispatchVectorBuilder<interpolate_fn_ptr_t, InterpolateFactory, num_types>
dtb_interpolate;
dtb_interpolate.populate_dispatch_vector(interpolate_dispatch_vector);
}

} // namespace impl

void init_interpolate(py::module_ m)
{
impl::init_interpolate_dispatch_vectors();

using impl::py_interpolate;
m.def("_interpolate", &py_interpolate, "", py::arg("x"), py::arg("idx"),
py::arg("xp"), py::arg("fp"), py::arg("left"), py::arg("right"),
py::arg("out"), py::arg("sycl_queue"),
py::arg("depends") = py::list());
}

} // namespace dpnp::extensions::ufunc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//*****************************************************************************
// Copyright (c) 2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace dpnp::extensions::ufunc
{
void init_interpolate(py::module_ m);
} // namespace dpnp::extensions::ufunc
Loading
Loading