Commit 77702b3

Improve performance of dnp.nan_to_num (#2228)
This PR adds a dedicated kernel for `dnp.nan_to_num` to improve its performance. This reduces the number of kernel calls to at most one in all cases.

Kernels for both strided and contiguous inputs have been added, so that no additional device memory has to be allocated for trivial strides when the input is fully C- or F-contiguous.

As an example of the performance gains on a Max GPU, on master:

```python
In [1]: import dpnp as dnp

In [2]: import numpy as np

In [3]: x_np = np.random.randn(10**9)

In [4]: x_np[np.random.choice(x_np.size, 200, replace=False)] = np.nan

In [5]: x = dnp.asarray(x_np)

In [6]: q = x.sycl_queue

In [7]: %time r = dnp.nan_to_num(x); q.wait()
CPU times: user 394 ms, sys: 43.8 ms, total: 438 ms
Wall time: 304 ms

In [8]: %time r = dnp.nan_to_num(x); q.wait()
CPU times: user 333 ms, sys: 31.8 ms, total: 364 ms
Wall time: 134 ms
```

On this branch:

```python
In [8]: %time r = dnp.nan_to_num(x); q.wait()
CPU times: user 49.6 ms, sys: 8.1 ms, total: 57.7 ms
Wall time: 60.9 ms

In [9]: %time r = dnp.nan_to_num(x); q.wait()
CPU times: user 22.9 ms, sys: 16 ms, total: 38.9 ms
Wall time: 19.7 ms
```
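To make the "at most one kernel call" idea concrete, here is a minimal pure-Python sketch. It is not the dpnp kernel and not the previous dpnp implementation. All function names are hypothetical. The multi-pass variant is an assumed stand-in for a version that composes several elementwise steps. The default replacements follow NumPy's documented `nan_to_num` defaults for float64 (NaN becomes `0.0`, infinities become the largest and smallest finite values).

```python
import math
import sys

# Largest finite double; NumPy's default replacement for +inf on float64.
FLOAT_MAX = sys.float_info.max


def replace_special(v, nan=0.0, posinf=FLOAT_MAX, neginf=-FLOAT_MAX):
    """Per-element rule: map NaN, +inf and -inf to finite replacements."""
    if math.isnan(v):
        return nan
    if v == math.inf:
        return posinf
    if v == -math.inf:
        return neginf
    return v


def nan_to_num_multi_pass(data):
    """Hypothetical multi-pass variant: three full sweeps over the data."""
    out = [0.0 if math.isnan(v) else v for v in data]
    out = [FLOAT_MAX if v == math.inf else v for v in out]
    out = [-FLOAT_MAX if v == -math.inf else v for v in out]
    return out


def nan_to_num_contig(data):
    """Single sweep over a contiguous buffer: element i lives at index i."""
    return [replace_special(v) for v in data]


def nan_to_num_strided(buffer, offset, stride, n):
    """Single sweep over a strided view: element i lives at offset + i * stride."""
    return [replace_special(buffer[offset + i * stride]) for i in range(n)]


data = [1.0, math.nan, math.inf, -math.inf, -2.5]
fused = nan_to_num_contig(data)
every_other = nan_to_num_strided(data, offset=0, stride=2, n=3)
```

The contiguous variant needs no offset or stride information at all, which mirrors the motivation stated above for keeping a separate contiguous path. The strided variant does the same per-element work but has to compute each element's position first.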
1 parent 5b140db commit 77702b3

7 files changed: +770 -19 lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

+1
@@ -38,6 +38,7 @@ set(_elementwise_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/nan_to_num.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/sinc.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/spacing.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

+2
@@ -38,6 +38,7 @@
 #include "lcm.hpp"
 #include "ldexp.hpp"
 #include "logaddexp2.hpp"
+#include "nan_to_num.hpp"
 #include "radians.hpp"
 #include "sinc.hpp"
 #include "spacing.hpp"
@@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
     init_lcm(m);
     init_ldexp(m);
     init_logaddexp2(m);
+    init_nan_to_num(m);
     init_radians(m);
     init_sinc(m);
     init_spacing(m);
