Skip to content

Commit

Permalink
fix repeat nr_rks_group
Browse files Browse the repository at this point in the history
  • Loading branch information
Miroier committed Feb 23, 2024
1 parent e3fa550 commit 3daa808
Showing 1 changed file with 0 additions and 139 deletions.
139 changes: 0 additions & 139 deletions gpu4pyscf/dft/numint.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,145 +692,6 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,

return nelec, excsum, vmat

def nr_rks_group(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
                 max_memory=2000, verbose=None):
    """Evaluate electron count, XC energy and XC matrix for a group of
    restricted density matrices.

    The work is done in three passes: (1) the density of every density
    matrix is evaluated on all grid blocks and stored in ``rho_tot``,
    (2) the XC functional is evaluated once per density matrix on the full
    grid, and (3) the weighted potential is integrated block by block into
    ``vmat``.

    Args:
        ni: numerical-integration object. Must provide ``_xc_type``,
            ``build``, ``block_loop`` and ``eval_xc_eff``; its ``gdftopt``
            attribute is created through ``ni.build`` when missing.
        mol: molecule, used for the logger and for ``ni.build``. It is
            replaced by ``ni.gdftopt.mol`` for all later calls.
        grids: integration grid exposing ``coords`` and ``weights``.
        xc_code: functional description forwarded to ``ni._xc_type`` and
            ``ni.eval_xc_eff``.
        dms: one density matrix or a stack of them; reshaped to
            ``(-1, nao0, nao0)``. If it carries ``mo_coeff`` / ``mo_occ``
            attributes, the density is built from the orbitals instead of
            from the matrices.
        relativity: not used by this implementation.
        hermi: not used by this implementation.
        max_memory: not used by this implementation.
        verbose: forwarded to ``logger.new_logger``.

    Returns:
        Tuple ``(nelec, excsum, vmat)`` with one entry per density matrix.
        When ``dms`` was passed as a single 2D matrix, the leading axis is
        dropped from all three.

    Raises:
        NotImplementedError: for LDA/GGA/MGGA when ``USE_SPARSITY != 2``,
            for ``'NLC'``, and for any unrecognised functional type.
    """
    log = logger.new_logger(mol, verbose)
    xctype = ni._xc_type(xc_code)
    opt = getattr(ni, 'gdftopt', None)
    if opt is None:
        ni.build(mol, grids.coords)
        opt = ni.gdftopt

    # Read the orbital data before ``dms`` is rebound to a plain array below.
    mo_coeff = getattr(dms, 'mo_coeff', None)
    mo_occ = getattr(dms, 'mo_occ', None)

    mol = opt.mol
    coeff = cupy.asarray(opt.coeff)
    nao, nao0 = coeff.shape
    dms = cupy.asarray(dms)
    dm_shape = dms.shape  # remembered to undo the stacking on return
    dms = dms.reshape(-1, nao0, nao0)
    # NOTE(review): presumably reorders the AOs into the ordering used by
    # ``opt.mol`` -- confirm against take_last2d / opt.ao_idx.
    dms = take_last2d(dms, opt.ao_idx)
    nset = len(dms)

    if mo_coeff is not None:
        mo_coeff = mo_coeff[opt.ao_idx]

    nelec = cupy.zeros(nset)
    excsum = cupy.zeros(nset)
    vmat = cupy.zeros((nset, nao, nao))

    release_gpu_stack()
    ao_deriv = 0 if xctype == 'LDA' else 1
    with_lapl = MGGA_DENSITY_LAPL
    ngrids = grids.weights.size

    # Number of density components stored per grid point.
    if xctype == 'LDA':
        ncomp = 1
    elif xctype == 'GGA':
        ncomp = 4
    else:
        ncomp = 6 if with_lapl else 5
    rho_tot = cupy.empty([nset, ncomp, ngrids])

    # Pass 1: density of every set on every grid block.
    p0 = p1 = 0
    t1 = t0 = log.init_timer()
    for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv):
        p1 = p0 + weight.size
        if mo_coeff is None:
            # The AO sub-block selector depends only on the grid block,
            # so build it once instead of once per density matrix.
            sub_block = np.ix_(idx, idx)
            for i in range(nset):
                rho_tot[i, :, p0:p1] = eval_rho(
                    mol, ao_mask, dms[i][sub_block],
                    xctype=xctype, hermi=1, with_lapl=with_lapl)
        else:
            # Same reasoning: gather the orbital rows once per grid block.
            mo_coeff_mask = mo_coeff[idx, :]
            for i in range(nset):
                rho_tot[i, :, p0:p1] = eval_rho2(
                    mol, ao_mask, mo_coeff_mask, mo_occ, None,
                    xctype, with_lapl)
        p0 = p1
        t1 = log.timer_debug2('eval rho slice', *t1)
    t0 = log.timer_debug1('eval rho', *t0)

    # Pass 2: XC energy density and potential on the full grid.
    wv = []
    for i in range(nset):
        # LDA passes the bare density, everything else the full stack.
        rho = rho_tot[i][0] if xctype == 'LDA' else rho_tot[i]
        exc, vxc = ni.eval_xc_eff(xc_code, rho, deriv=1, xctype=xctype)[:2]
        vxc = cupy.asarray(vxc, order='C')
        exc = cupy.asarray(exc, order='C')
        den = rho_tot[i][0] * grids.weights
        nelec[i] = den.sum()
        excsum[i] = cupy.sum(den * exc[:, 0])
        weighted_vxc = vxc * grids.weights
        # NOTE(review): components 0 (and 4 for MGGA) are halved,
        # presumably because transpose_sum() below adds the transpose
        # for non-LDA functionals -- confirm.
        if xctype == 'GGA':
            weighted_vxc[0] *= .5
        if xctype == 'MGGA':
            weighted_vxc[[0, 4]] *= .5
        wv.append(weighted_vxc)
    t0 = log.timer_debug1('eval vxc', *t0)

    # Pass 3: integrate the weighted potential into vmat, block by block.
    t1 = t0
    p0 = p1 = 0
    for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv):
        p1 = p0 + weight.size
        for i in range(nset):
            if xctype in ('LDA', 'GGA', 'MGGA') and USE_SPARSITY != 2:
                raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented')
            if xctype == 'LDA':
                aow = _scale_ao(ao_mask, wv[i][0, p0:p1])
                add_sparse(vmat[i], ao_mask.dot(aow.T), idx)
            elif xctype == 'GGA':
                aow = _scale_ao(ao_mask, wv[i][:, p0:p1])
                add_sparse(vmat[i], ao_mask[0].dot(aow.T), idx)
            elif xctype == 'NLC':
                raise NotImplementedError('NLC')
            elif xctype == 'MGGA':
                aow = _scale_ao(ao_mask, wv[i][:4, p0:p1])
                vtmp = ao_mask[0].dot(aow.T)
                vtmp += _tau_dot(ao_mask, ao_mask, wv[i][4, p0:p1])
                add_sparse(vmat[i], vtmp, idx)
            elif xctype == 'HF':
                pass
            else:
                raise NotImplementedError(f'numint.nr_rks for functional {xc_code}')
        p0 = p1
        t1 = log.timer_debug2('integration', *t1)
    t0 = log.timer_debug1('vxc integration', *t0)

    # NOTE(review): presumably maps vmat back to the caller's AO ordering
    # -- confirm against opt.rev_ao_idx.
    rev_ao_idx = opt.rev_ao_idx
    vmat = take_last2d(vmat, rev_ao_idx)

    if xctype != 'LDA':
        transpose_sum(vmat)

    if FREE_CUPY_CACHE:
        # Drop the local reference first so its blocks can be released.
        dms = None
        cupy.get_default_memory_pool().free_all_blocks()

    # A single 2D density matrix in -> un-stacked results out.
    if len(dm_shape) == 2:
        nelec = nelec[0]
        excsum = excsum[0]
        vmat = vmat[0]

    return nelec, excsum, vmat

def nr_rks_group(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
max_memory=2000, verbose=None):
log = logger.new_logger(mol, verbose)
Expand Down

0 comments on commit 3daa808

Please sign in to comment.