Skip to content

Commit

Permalink
improve code style
Browse files Browse the repository at this point in the history
  • Loading branch information
Miroier committed Feb 23, 2024
1 parent b250c9c commit 68286e1
Showing 1 changed file with 27 additions and 29 deletions.
56 changes: 27 additions & 29 deletions gpu4pyscf/dft/numint.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,23 +741,22 @@ def nr_rks_group(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
t1 = t0 = log.init_timer()
# TODO: replace ni.block_loop with ni.grouped_block_loop
for ao_mask_group, idx_group, weight_group, _ in ni.grouped_block_loop(mol, grids, nao, ao_deriv):
groups = len(ao_mask_group)
p0_raw = p0
for i in range(nset):
p0 = p0_raw
if mo_coeff is None:
for groups_idx in range(groups):
p1 = p0 + weight_group[groups_idx].size
rho_tot[i,:,p0:p1] = eval_rho(mol, ao_mask_group[groups_idx], dms[i][np.ix_(idx_group[groups_idx], idx_group[groups_idx])],
for ao_mask, idx, weight in zip(ao_mask_group, idx_group, weight_group):
p1 = p0 + weight.size
rho_tot[i,:,p0:p1] = eval_rho(mol, ao_mask, dms[i][np.ix_(idx, idx)],
xctype=xctype, hermi=1, with_lapl=with_lapl)
p0 = p1
else:
mo_coeff_mask_group = [mo_coeff[idx,:] for idx in idx_group]
# TODO: create eval_rho5 for grouped ao_mask
eval_rho5_res = eval_rho5(mol, ao_mask_group, mo_coeff_mask_group, mo_occ, None, xctype, with_lapl)
for groups_idx in range(groups):
p1 = p0 + weight_group[groups_idx].size
rho_tot[i,:,p0:p1] = eval_rho5_res[groups_idx]
eval_rho5_res_group = eval_rho5(mol, ao_mask_group, mo_coeff_mask_group, mo_occ, None, xctype, with_lapl)
for weight, eval_rho5_res in zip(weight_group, eval_rho5_res_group):
p1 = p0 + weight.size
rho_tot[i,:,p0:p1] = eval_rho5_res
p0 = p1
t1 = log.timer_debug2('eval rho slice', *t1)
t0 = log.timer_debug1('eval rho', *t0)
Expand Down Expand Up @@ -791,44 +790,43 @@ def nr_rks_group(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
p0 = p0_raw
if xctype == 'LDA':
aow_group = []
for groups_idx in range(groups):
p1 = p0 + weight_group[groups_idx].size
aow = _scale_ao(ao_mask_group[groups_idx], wv[i][0,p0:p1])
for weight, ao_mask in zip(weight_group, ao_mask_group):
p1 = p0 + weight.size
aow = _scale_ao(ao_mask, wv[i][0,p0:p1])
p0 = p1
aow_group.append(aow)
dot_res_group = grouped_dot(ao_mask_group, aow_group)
for groups_idx in range(groups):
add_sparse(vmat[i], dot_res_group[groups_idx], idx_group[groups_idx])
for dot_res, idx in zip(dot_res_group, idx_group):
add_sparse(vmat[i], dot_res, idx)
elif xctype == 'GGA':
aow_group = []
ao_mask_0_group = []
for groups_idx in range(groups):
p1 = p0 + weight_group[groups_idx].size
aow = _scale_ao(ao_mask_group[groups_idx], wv[i][:,p0:p1])
for weight, ao_mask in zip(weight_group, ao_mask_group):
p1 = p0 + weight.size
aow = _scale_ao(ao_mask, wv[i][:,p0:p1])
p0 = p1
aow_group.append(aow)
ao_mask_0_group.append(ao_mask_group[groups_idx][0])
ao_mask_0_group.append(ao_mask[0])
dot_res_group = grouped_dot(ao_mask_0_group, aow_group)
for groups_idx in range(groups):
add_sparse(vmat[i], dot_res_group[groups_idx], idx_group[groups_idx])
for dot_res, idx in zip(dot_res_group, idx_group):
add_sparse(vmat[i], dot_res, idx)
elif xctype == 'NLC':
raise NotImplementedError('NLC')
elif xctype == 'MGGA':
aow_group = []
ao_mask_0_group = []
p0_raw = p0
for groups_idx in range(groups):
p1 = p0 + weight_group[groups_idx].size
aow = _scale_ao(ao_mask_group[groups_idx], wv[i][:4,p0:p1])
p0_tmp = p0
for weight, ao_mask in zip(weight_group, ao_mask_group):
p1 = p0 + weight.size
aow = _scale_ao(ao_mask, wv[i][:4,p0:p1])
p0 = p1
aow_group.append(aow)
ao_mask_0_group.append(ao_mask_group[groups_idx][0])
ao_mask_0_group.append(ao_mask[0])
dot_res_group = grouped_dot(ao_mask_0_group, aow_group)
p0 = p0_raw
for groups_idx in range(groups):
p1 = p0 + weight_group[groups_idx].size
add_sparse(vmat[i], dot_res_group[groups_idx] + _tau_dot(ao_mask_group[groups_idx], ao_mask_group[groups_idx], wv[i][4,p0:p1]),
idx_group[groups_idx])
p0 = p0_tmp
for weight, dot_res, ao_mask, idx in zip(weight_group, dot_res_group, ao_mask_group, idx_group):
p1 = p0 + weight.size
add_sparse(vmat[i], dot_res + _tau_dot(ao_mask, ao_mask, wv[i][4,p0:p1]), idx)
p0 = p1
elif xctype == 'HF':
pass
Expand Down

0 comments on commit 68286e1

Please sign in to comment.