Skip to content

Commit

Permalink
Use use_threads_if
Browse files Browse the repository at this point in the history
  • Loading branch information
matwey committed Jan 29, 2025
1 parent 79e801d commit 34ab5b1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = [
"setuptools",
"wheel",
"numpy>=2.0",
"Cython",
"Cython==3.1.0a1",
]
build-backend = "setuptools.build_meta"

Expand Down
9 changes: 8 additions & 1 deletion src/coniferest/calc_paths_sum.pyx
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# cython: profile=True

import numpy as np
from cython.parallel cimport prange, parallel

cimport numpy as np
cimport cython
cimport openmp


def calc_paths_sum(selector_t [::1] selectors,
Expand All @@ -21,6 +24,9 @@ def calc_paths_sum(selector_t [::1] selectors,
if indices[-1] > sellen:
raise ValueError('indices are out of range of the selectors')

if num_threads < 0:
num_threads = openmp.omp_get_max_threads()

_paths_sum(selectors, indices, data, paths_view, weights, num_threads, chunksize)
return paths

Expand Down Expand Up @@ -101,8 +107,9 @@ cdef void _paths_sum(selector_t [::1] selectors,
cdef selector_t selector
cdef Py_ssize_t tree_offset
cdef np.int32_t feature, i
cdef int use_threads_if = (2 * num_threads < data.shape[0])

with nogil, parallel(num_threads=num_threads):
with nogil, parallel(num_threads=num_threads, use_threads_if=use_threads_if):
trees = indices.shape[0] - 1

for x_index in prange(data.shape[0], schedule='static', chunksize=chunksize):
Expand Down

0 comments on commit 34ab5b1

Please sign in to comment.