Skip to content

Commit

Permalink
Reduce the number of numba threads
Browse files Browse the repository at this point in the history
Too many numba threads are causing allocator contention. Therefore
we can limit the number of numba threads to 8.

Signed-off-by: Patrick Avery <[email protected]>
  • Loading branch information
psavery committed Jan 5, 2024
1 parent 7101380 commit 3f59e9e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
31 changes: 16 additions & 15 deletions hexrd/matrixutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from scipy import sparse

from hexrd.utils.decorators import numba_njit_if_available
from hexrd.utils.decorators import limit_numba_threads, numba_njit_if_available
from hexrd import constants
from hexrd.constants import USE_NUMBA
if USE_NUMBA:
Expand Down Expand Up @@ -697,6 +697,9 @@ def findDuplicateVectors(vec, tol=vTol, equivPM=False):
return eqv2, uid2


# We found that too many threads causes allocator contention,
# so limit the number of threads here to just 8.
@limit_numba_threads(8)
@numba_njit_if_available(cache=True, nogil=True, parallel=True)
def _findduplicatevectors(vec, tol, equivPM):
"""
Expand Down Expand Up @@ -749,20 +752,18 @@ def _findduplicatevectors(vec, tol, equivPM):
ctr = 0
eqv_elem = np.zeros((m, ), dtype=np.int64)

for jj in prange(m):
if jj > ii:

if equivPM:
diff = np.sum(np.abs(vec[:, ii]-vec2[:, jj]))
diff2 = np.sum(np.abs(vec[:, ii]-vec[:, jj]))
if diff < tol or diff2 < tol:
eqv_elem[ctr] = jj
ctr += 1
else:
diff = np.sum(np.abs(vec[:, ii]-vec[:, jj]))
if diff < tol:
eqv_elem[ctr] = jj
ctr += 1
for jj in prange(ii, m):
if equivPM:
diff = np.sum(np.abs(vec[:, ii]-vec2[:, jj]))
diff2 = np.sum(np.abs(vec[:, ii]-vec[:, jj]))
if diff < tol or diff2 < tol:
eqv_elem[ctr] = jj
ctr += 1
else:
diff = np.sum(np.abs(vec[:, ii]-vec[:, jj]))
if diff < tol:
eqv_elem[ctr] = jj
ctr += 1

for kk in range(ctr):
eqv[ii, kk] = eqv_elem[kk]
Expand Down
18 changes: 18 additions & 0 deletions hexrd/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import OrderedDict
from functools import wraps

import numba
import numpy as np
import xxhash

Expand Down Expand Up @@ -139,3 +140,20 @@ def decorator(func):
from numba import prange
else:
prange = range


# A decorator to limit the number of numba threads
def limit_numba_threads(max_threads):
def decorator(func):
def wrapper(*args, **kwargs):
prev_num_threads = numba.get_num_threads()
new_num_threads = min(prev_num_threads, max_threads)
numba.set_num_threads(new_num_threads)
try:
return func(*args, **kwargs)
finally:
numba.set_num_threads(prev_num_threads)

return wrapper

return decorator

0 comments on commit 3f59e9e

Please sign in to comment.