diff --git a/hexrd/material/crystallography.py b/hexrd/material/crystallography.py index ee186b364..d3d5c63c4 100644 --- a/hexrd/material/crystallography.py +++ b/hexrd/material/crystallography.py @@ -700,6 +700,9 @@ def __init__(self, raise RuntimeError('have unparsed keyword arguments with keys: ' + str(list(kwargs.keys()))) + # This is only used to calculate the structure factor if invalidated + self.__unitcell = None + self.__calc() return @@ -935,7 +938,27 @@ def set_wavelength(self, wavelength): wavelength = property(get_wavelength, set_wavelength, None) + def invalidate_structure_factor(self, unitcell): + # It can be expensive to compute the structure factor, so provide the + # option to just invalidate it, while providing a unit cell, so that + # it can be lazily computed from the unit cell. + self.__structFact = None + self._powder_intensity = None + self.__unitcell = unitcell + + def _compute_sf_if_needed(self): + any_invalid = ( + self.__structFact is None or + self._powder_intensity is None + ) + if any_invalid and self.__unitcell is not None: + # Compute the structure factor first. + # This can be expensive to do, so we lazily compute it when needed. + hkls = self.getHKLs(allHKLs=True) + self.set_structFact(self.__unitcell.CalcXRSF(hkls)) + def get_structFact(self): + self._compute_sf_if_needed() return self.__structFact[~self.exclusions] def set_structFact(self, structFact): @@ -953,6 +976,7 @@ def set_structFact(self, structFact): @property def powder_intensity(self): + self._compute_sf_if_needed() return self._powder_intensity[~self.exclusions] @staticmethod diff --git a/hexrd/material/material.py b/hexrd/material/material.py index abb139f08..ffd775248 100644 --- a/hexrd/material/material.py +++ b/hexrd/material/material.py @@ -226,7 +226,7 @@ def __init__( self.reset_v0() self._newPdata() - self.update_structure_factor() + self.invalidate_structure_factor() def __str__(self): """String representation""" @@ -291,7 +291,7 @@ def _newUnitcell(self): def _hkls_changed(self): # Call this when something happens that changes the hkls... self._newPdata() - self.update_structure_factor() + self.invalidate_structure_factor() def _newPdata(self): """Create a new plane data instance if the hkls have changed""" @@ -405,10 +405,8 @@ def enable_hkls_below_tth(self, tth_threshold=90.0): self._pData.exclusions = dflt_excl - def update_structure_factor(self): - hkls = self.planeData.getHKLs(allHKLs=True) - sf = self.unitcell.CalcXRSF(hkls) - self.planeData.set_structFact(sf) + def invalidate_structure_factor(self): + self.planeData.invalidate_structure_factor(self.unitcell) def compute_powder_overlay( self, ttharray=np.linspace(0, 80, 2000), fwhm=0.25, scale=1.0 @@ -1268,7 +1266,7 @@ def charge(self, vals): self._charge = vals # self._newUnitcell() - # self.update_structure_factor() + # self.invalidate_structure_factor() @property def absorption_length(self): @@ -1390,7 +1388,7 @@ def _set_atomdata(self, atomtype, atominfo, U, charge): self.charge = charge self._newUnitcell() - self.update_structure_factor() + self.invalidate_structure_factor() # # ========== Methods diff --git a/hexrd/matrixutil.py b/hexrd/matrixutil.py index 9ee43b5a4..12569a21b 100644 --- a/hexrd/matrixutil.py +++ b/hexrd/matrixutil.py @@ -32,11 +32,12 @@ from scipy import sparse +from hexrd.utils.decorators import numba_njit_if_available from hexrd import constants from hexrd.constants import USE_NUMBA if USE_NUMBA: import numba - + from numba import prange # module variables sqr6i = 1./np.sqrt(6.) @@ -582,7 +583,7 @@ def uniqueVectors(v, tol=1.0e-12): return vSrt[:, ivInd[0:nUniq]] -def findDuplicateVectors(vec, tol=vTol, equivPM=False): +def findDuplicateVectors_old(vec, tol=vTol, equivPM=False): """ Find vectors in an array that are equivalent to within a specified tolerance @@ -682,6 +683,92 @@ def findDuplicateVectors(vec, tol=vTol, equivPM=False): return eqv, uid +def findDuplicateVectors(vec, tol=vTol, equivPM=False): + eqv = _findduplicatevectors(vec, tol, equivPM) + uid = np.arange(0, vec.shape[1], dtype=np.int64) + mask = ~np.isnan(eqv) + idx = eqv[mask].astype(np.int64) + uid2 = list(np.delete(uid, idx)) + eqv2 = [] + for ii in range(eqv.shape[0]): + v = eqv[ii, mask[ii, :]] + if v.shape[0] > 0: + eqv2.append([ii] + list(v.astype(np.int64))) + return eqv2, uid2 + + +@numba_njit_if_available(cache=True, nogil=True) +def _findduplicatevectors(vec, tol, equivPM): + """ + Find vectors in an array that are equivalent to within + a specified tolerance. code is accelerated by numba + + USAGE: + + eqv = DuplicateVectors(vec, *tol) + + INPUT: + + 1) vec is n x m, a double array of m horizontally concatenated + n-dimensional vectors. + *2) tol is 1 x 1, a scalar tolerance. If not specified, the default + tolerance is 1e-14. + *3) set equivPM to True if vec and -vec + are to be treated as equivalent + + OUTPUT: + + 1) eqv is 1 x p, a list of p equivalence relationships. + + NOTES: + + Each equivalence relationship is a 1 x q vector of indices that + represent the locations of duplicate columns/entries in the array + vec. For example: + + | 1 2 2 2 1 2 7 | + vec = | | + | 2 3 5 3 2 3 3 | + + eqv = [[1x2 double] [1x3 double]], where + + eqv[0] = [0 4] + eqv[1] = [1 3 5] + """ + + if equivPM: + vec2 = -vec.copy() + + n = vec.shape[0] + m = vec.shape[1] + + eqv = np.zeros((m, m), dtype=np.float64) + eqv[:] = np.nan + eqv_elem_master = [] + + for ii in range(m): + ctr = 0 + eqv_elem = np.zeros((m, ), dtype=np.int64) + for jj in range(ii+1, m): + if not jj in eqv_elem_master: + if equivPM: + diff = np.sum(np.abs(vec[:, ii]-vec2[:, jj])) + diff2 = np.sum(np.abs(vec[:, ii]-vec[:, jj])) + if diff < tol or diff2 < tol: + eqv_elem[ctr] = jj + eqv_elem_master.append(jj) + ctr += 1 + else: + diff = np.sum(np.abs(vec[:, ii]-vec[:, jj])) + if diff < tol: + eqv_elem[ctr] = jj + eqv_elem_master.append(jj) + ctr += 1 + + for kk in range(ctr): + eqv[ii, kk] = eqv_elem[kk] + + return eqv def normvec(v): mag = np.linalg.norm(v) diff --git a/hexrd/utils/decorators.py b/hexrd/utils/decorators.py index 134b4497f..f6716d4fd 100644 --- a/hexrd/utils/decorators.py +++ b/hexrd/utils/decorators.py @@ -9,6 +9,7 @@ from collections import OrderedDict from functools import wraps +import numba import numpy as np import xxhash @@ -139,3 +140,20 @@ def decorator(func): from numba import prange else: prange = range + + +# A decorator to limit the number of numba threads +def limit_numba_threads(max_threads): + def decorator(func): + def wrapper(*args, **kwargs): + prev_num_threads = numba.get_num_threads() + new_num_threads = min(prev_num_threads, max_threads) + numba.set_num_threads(new_num_threads) + try: + return func(*args, **kwargs) + finally: + numba.set_num_threads(prev_num_threads) + + return wrapper + + return decorator