From 3f59e9e49b22beaec13bc71c1a2f969df11c598e Mon Sep 17 00:00:00 2001 From: Patrick Avery Date: Fri, 5 Jan 2024 16:37:25 -0600 Subject: [PATCH] Reduce the number of numba threads Too many numba threads are causing allocator contention. Therefore we can limit the number of numba threads to 8. Signed-off-by: Patrick Avery --- hexrd/matrixutil.py | 31 ++++++++++++++++--------------- hexrd/utils/decorators.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/hexrd/matrixutil.py b/hexrd/matrixutil.py index aeb10d5d4..73a81a3a3 100644 --- a/hexrd/matrixutil.py +++ b/hexrd/matrixutil.py @@ -32,7 +32,7 @@ from scipy import sparse -from hexrd.utils.decorators import numba_njit_if_available +from hexrd.utils.decorators import limit_numba_threads, numba_njit_if_available from hexrd import constants from hexrd.constants import USE_NUMBA if USE_NUMBA: @@ -697,6 +697,9 @@ def findDuplicateVectors(vec, tol=vTol, equivPM=False): return eqv2, uid2 +# We found that too many threads causes allocator contention, +# so limit the number of threads here to just 8. +@limit_numba_threads(8) @numba_njit_if_available(cache=True, nogil=True, parallel=True) def _findduplicatevectors(vec, tol, equivPM): """ @@ -749,20 +752,18 @@ def _findduplicatevectors(vec, tol, equivPM): ctr = 0 eqv_elem = np.zeros((m, ), dtype=np.int64) - for jj in prange(m): - if jj > ii: - - if equivPM: - diff = np.sum(np.abs(vec[:, ii]-vec2[:, jj])) - diff2 = np.sum(np.abs(vec[:, ii]-vec[:, jj])) - if diff < tol or diff2 < tol: - eqv_elem[ctr] = jj - ctr += 1 - else: - diff = np.sum(np.abs(vec[:, ii]-vec[:, jj])) - if diff < tol: - eqv_elem[ctr] = jj - ctr += 1 + for jj in prange(ii, m): + if equivPM: + diff = np.sum(np.abs(vec[:, ii]-vec2[:, jj])) + diff2 = np.sum(np.abs(vec[:, ii]-vec[:, jj])) + if diff < tol or diff2 < tol: + eqv_elem[ctr] = jj + ctr += 1 + else: + diff = np.sum(np.abs(vec[:, ii]-vec[:, jj])) + if diff < tol: + eqv_elem[ctr] = jj + ctr += 1 for kk in range(ctr): eqv[ii, kk] = eqv_elem[kk] diff --git a/hexrd/utils/decorators.py b/hexrd/utils/decorators.py index 134b4497f..f6716d4fd 100644 --- a/hexrd/utils/decorators.py +++ b/hexrd/utils/decorators.py @@ -9,6 +9,7 @@ from collections import OrderedDict from functools import wraps +import numba import numpy as np import xxhash @@ -139,3 +140,20 @@ def decorator(func): from numba import prange else: prange = range + + +# A decorator to limit the number of numba threads +def limit_numba_threads(max_threads): + def decorator(func): + def wrapper(*args, **kwargs): + prev_num_threads = numba.get_num_threads() + new_num_threads = min(prev_num_threads, max_threads) + numba.set_num_threads(new_num_threads) + try: + return func(*args, **kwargs) + finally: + numba.set_num_threads(prev_num_threads) + + return wrapper + + return decorator