Skip to content

Commit

Permalink
Improve Krylov subspace solver in CPHF (pyscf#156)
Browse files Browse the repository at this point in the history
* improve df scheme

* separate krylov subspace in cphf

* sg1_prune

* block krylov subspace

* debug -> info

* consistent linear dependence thresholds for df

* df inv -> solve

* df inv -> solve in uhf

* bugfix

* bugfix

* dynamic slicing in df hessian

* reduce the computational cost in df Hessian

* bugfix

* raise error if CPHF does not converge

* vectorize krylov subspace

* conv_tol_cpscf = 1e-6

* ucphf

* unit test
  • Loading branch information
wxj6000 authored May 25, 2024
1 parent 67ac50b commit 761f85f
Show file tree
Hide file tree
Showing 16 changed files with 229 additions and 156 deletions.
2 changes: 1 addition & 1 deletion examples/00-h2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
atom=atom, # water molecule
basis='def2-tzvpp', # basis set
output='./pyscf.log', # save log file
verbose=1 # control the level of print info
verbose=6 # control the level of print info
)

mf_GPU = rks.RKS( # restricted Kohn-Sham DFT
Expand Down
4 changes: 2 additions & 2 deletions examples/dft_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
basis=bas,
max_memory=32000)
# set verbose >= 6 for debugging timer
mol.verbose = 4
mol.verbose = 6

if args.unrestricted:
mf_df = uks.UKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
else:
mf_df = rks.RKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
mf_df.verbose = 4
mf_df.verbose = 6

if args.solvent:
mf_df = mf_df.PCM()
Expand Down
6 changes: 4 additions & 2 deletions gpu4pyscf/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

MIN_BLK_SIZE = getattr(__config__, 'min_ao_blksize', 128)
ALIGNED = getattr(__config__, 'ao_aligned', 32)
LINEAR_DEP_TOL = incore.LINEAR_DEP_THR

# TODO: reuse the setting in pyscf 2.6
LINEAR_DEP_THR = 1e-6#incore.LINEAR_DEP_THR

class DF(lib.StreamObject):
from gpu4pyscf.lib.utils import to_gpu, device
Expand Down Expand Up @@ -93,7 +95,7 @@ def build(self, direct_scf_tol=1e-14, omega=None):
self.cd_low = tag_array(self.cd_low, tag='cd')
except Exception:
w, v = cupy.linalg.eigh(j2c)
idx = w > LINEAR_DEP_TOL
idx = w > LINEAR_DEP_THR
self.cd_low = (v[:,idx] / cupy.sqrt(w[idx]))
self.cd_low = tag_array(self.cd_low, tag='eig')

Expand Down
26 changes: 24 additions & 2 deletions gpu4pyscf/df/grad/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,39 @@
import cupy
from cupyx.scipy.linalg import solve_triangular
from pyscf import scf
from gpu4pyscf.df import int3c2e
from gpu4pyscf.lib.cupy_helper import print_mem_info, tag_array, unpack_tril, contract, load_library, take_last2d
from gpu4pyscf.df import int3c2e, df
from gpu4pyscf.lib.cupy_helper import (print_mem_info, tag_array,
unpack_tril, contract, load_library, take_last2d, cholesky)
from gpu4pyscf.grad import rhf as rhf_grad
from gpu4pyscf import __config__
from gpu4pyscf.lib import logger

libcupy_helper = load_library('libcupy_helper')

LINEAR_DEP_THRESHOLD = df.LINEAR_DEP_THR
MIN_BLK_SIZE = getattr(__config__, 'min_ao_blksize', 128)
ALIGNED = getattr(__config__, 'ao_aligned', 64)

def _gen_metric_solver(int2c, decompose_j2c='CD', lindep=LINEAR_DEP_THRESHOLD):
    '''Generate a solver for the metric equation ``int2c @ x = b``.

    Args:
        int2c: square metric matrix.
        decompose_j2c: 'CD' (case-insensitive) tries a Cholesky
            factorisation first; any other value, or a failed
            factorisation, uses the eigen-decomposition below.
        lindep: eigenvalues not greater than this value are dropped when
            the eigen-decomposition builds the (pseudo-)inverse.

    Returns:
        A callable that takes a right-hand side ``b`` whose leading
        dimension matches ``int2c`` (shape ``(n, ...)``) and returns the
        solution with the same shape as ``b``.
    '''
    if decompose_j2c.upper() == 'CD':
        try:
            # NOTE(review): assumes the helper returns the lower factor L
            # with int2c = L L^H -- confirm against cupy_helper.cholesky.
            j2c_low = cholesky(int2c, lower=True)
        except Exception:
            # Deliberate best-effort: if the factorisation fails for any
            # reason, fall back to the eigen-decomposition path.
            j2c_low = None

        if j2c_low is not None:
            naux = j2c_low.shape[0]

            def cd_solver(b):
                # solve_triangular only accepts 1-D/2-D right-hand sides,
                # so flatten the trailing dimensions first.
                rhs = b.reshape(naux, -1)
                # A x = b with A = L L^H  ->  L y = b, then L^H x = y.
                tmp = solve_triangular(j2c_low, rhs, lower=True,
                                       overwrite_b=False)
                # tmp is a temporary created above, so it may be reused.
                sol = solve_triangular(j2c_low, tmp, trans='C', lower=True,
                                       overwrite_b=True)
                return sol.reshape(b.shape)
            return cd_solver

    # Eigen-decomposition path: build the inverse restricted to the
    # eigenvectors whose eigenvalues exceed the linear-dependence threshold.
    w, v = cupy.linalg.eigh(int2c)
    mask = w > lindep
    v1 = v[:, mask]
    j2c_inv = cupy.dot(v1 / w[mask], v1.conj().T)

    def eig_solver(b):
        return j2c_inv.dot(b.reshape(j2c_inv.shape[0], -1)).reshape(b.shape)
    return eig_solver

def get_jk(mf_grad, mol=None, dm0=None, hermi=0, with_j=True, with_k=True, omega=None):
if mol is None: mol = mf_grad.mol
#TODO: dm has to be the SCF density matrix in this version. dm should be
Expand Down
86 changes: 50 additions & 36 deletions gpu4pyscf/df/hessian/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,23 @@
'''



import numpy
import cupy
import numpy as np
from pyscf import lib, df
from pyscf.df.grad.rhf import LINEAR_DEP_THRESHOLD
from pyscf import lib
from pyscf.df.incore import LINEAR_DEP_THR
from gpu4pyscf.grad import rhf as rhf_grad
from gpu4pyscf.hessian import rhf as rhf_hess
from gpu4pyscf.lib.cupy_helper import (
contract, tag_array, release_gpu_stack, print_mem_info, take_last2d, pinv)
from gpu4pyscf.df import int3c2e
contract, tag_array, get_avail_mem, release_gpu_stack, print_mem_info, take_last2d, pinv)
from gpu4pyscf.df import int3c2e, df
from gpu4pyscf.lib import logger
from gpu4pyscf import __config__
from gpu4pyscf.df.grad.rhf import _gen_metric_solver

LINEAR_DEP_THRESHOLD = df.LINEAR_DEP_THR
BLKSIZE = 256
ALIGNED = getattr(__config__, 'ao_aligned', 32)

def partial_hess_elec(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
atmlst=None, max_memory=4000, verbose=None):
Expand Down Expand Up @@ -101,6 +104,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
int2c = cupy.asarray(int2c, order='C')
int2c = take_last2d(int2c, aux_ao_idx)
int2c_inv = pinv(int2c, lindep=LINEAR_DEP_THRESHOLD)
solve_j2c = _gen_metric_solver(int2c)
int2c = None

int2c_ip1 = cupy.asarray(int2c_ip1, order='C')
Expand All @@ -114,8 +118,8 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,

# int3c contributions
wj, wk_P__ = int3c2e.get_int3c2e_jk(mol, auxmol, dm0_tag, omega=omega)
rhoj0_P = contract('pq,q->p', int2c_inv, wj)
rhok0_P__ = contract('pq,qij->pij', int2c_inv, wk_P__)
rhoj0_P = solve_j2c(wj)
rhok0_P__ = solve_j2c(wk_P__)
wj = wk_P__ = None
t1 = log.timer_debug1('intermediate variables with int3c2e', *t1)

Expand All @@ -125,56 +129,64 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,

# int3c_ip1 contributions
wj1_P, wk1_Pko = int3c2e.get_int3c2e_ip1_wjk(intopt, dm0_tag, omega=omega)
rhoj1_P = contract('pq,ipx->iqx', int2c_inv, wj1_P)
#rhoj1_P = contract('pq,pix->qix', int2c_inv, wj1_P)
rhoj1_P = solve_j2c(wj1_P)

hj_ao_ao += 4.0*contract('ipx,jpy->ijxy', rhoj1_P, wj1_P) # (10|0)(0|0)(0|01)
hj_ao_ao += 4.0*contract('pix,pjy->ijxy', rhoj1_P, wj1_P) # (10|0)(0|0)(0|01)
wj1_P = None
if hessobj.auxbasis_response:
wj0_01 = contract('ypq,q->yp', int2c_ip1, rhoj0_P)
wj1_01 = contract('yqp,ipx->iqxy', int2c_ip1, rhoj1_P)
hj_ao_aux += contract('ipx,py->ipxy', rhoj1_P, wj_ip2) # (10|0)(1|00)
hj_ao_aux -= contract('ipx,yp->ipxy', rhoj1_P, wj0_01) # (10|0)(1|0)(0|00)
wj1_01 = contract('yqp,pix->iqxy', int2c_ip1, rhoj1_P)
hj_ao_aux += contract('pix,py->ipxy', rhoj1_P, wj_ip2) # (10|0)(1|00)
hj_ao_aux -= contract('pix,yp->ipxy', rhoj1_P, wj0_01) # (10|0)(1|0)(0|00)
hj_ao_aux -= contract('q,iqxy->iqxy', rhoj0_P, wj1_01) # (10|0)(0|1)(0|00)
wj1_01 = None
rhoj1_P = None

int2c_ip1_inv = contract('yqp,pr->yqr', int2c_ip1, int2c_inv)
if with_k:
for i0, i1 in lib.prange(0,nao,64):
wk1_Pko_islice = cupy.asarray(wk1_Pko[i0:i1])
rhok1_Pko = contract('pq,iqox->ipox', int2c_inv, wk1_Pko_islice)
for k0, k1 in lib.prange(0,nao,64):
wk1_Pko_kslice = cupy.asarray(wk1_Pko[k0:k1])
mem_avail = get_avail_mem()
nocc = mocc.shape[1]
slice_size = naux*nocc*9 # largest slice of intermediate variables
blksize = int(mem_avail*0.2/8/slice_size/ALIGNED) * ALIGNED
for i0, i1 in lib.prange(0,nao,blksize):
wk1_Pko_islice = cupy.asarray(wk1_Pko[:,i0:i1])
#rhok1_Pko = contract('pq,qiox->piox', int2c_inv, wk1_Pko_islice)
rhok1_Pko = solve_j2c(wk1_Pko_islice)
wk1_Pko_islice = None
for k0, k1 in lib.prange(0,nao,blksize):
wk1_Pko_kslice = cupy.asarray(wk1_Pko[:,k0:k1])

# (10|0)(0|10) without response of RI basis
vk2_ip1_ip1 = contract('ipox,kpoy->ikxy', rhok1_Pko, wk1_Pko_kslice)
vk2_ip1_ip1 = contract('piox,pkoy->ikxy', rhok1_Pko, wk1_Pko_kslice)
hk_ao_ao[i0:i1,k0:k1] += contract('ikxy,ik->ikxy', vk2_ip1_ip1, dm0[i0:i1,k0:k1])
vk2_ip1_ip1 = None

# (10|0)(0|01) without response of RI basis
bra = contract('ipox,ko->ipkx', rhok1_Pko, mocc_2[k0:k1])
ket = contract('kpoy,io->kpiy', wk1_Pko_kslice, mocc_2[i0:i1])
hk_ao_ao[i0:i1,k0:k1] += contract('ipkx,kpiy->ikxy', bra, ket)
bra = contract('piox,ko->pikx', rhok1_Pko, mocc_2[k0:k1])
ket = contract('pkoy,io->pkiy', wk1_Pko_kslice, mocc_2[i0:i1])
hk_ao_ao[i0:i1,k0:k1] += contract('pikx,pkiy->ikxy', bra, ket)
bra = ket = None
wk1_Pko_kslice = None
if hessobj.auxbasis_response:
# (10|0)(1|00)
wk_ip2_Ipo = contract('porx,io->iprx', wk_ip2_P__, mocc_2[i0:i1])
hk_ao_aux[i0:i1] += contract('ipox,ipoy->ipxy', rhok1_Pko, wk_ip2_Ipo)
wk_ip2_Ipo = contract('porx,io->pirx', wk_ip2_P__, mocc_2[i0:i1])
hk_ao_aux[i0:i1] += contract('piox,pioy->ipxy', rhok1_Pko, wk_ip2_Ipo)
wk_ip2_Ipo = None

# (10|0)(1|0)(0|00)
rhok0_P_I = contract('qor,ir->qoi', rhok0_P__, mocc_2[i0:i1])
wk1_P_I = contract('ypq,qoi->ipoy', int2c_ip1, rhok0_P_I)
hk_ao_aux[i0:i1] -= contract("ipox,ipoy->ipxy", rhok1_Pko, wk1_P_I)
wk1_P_I = rhok1_Pko = None
wk1_P_I = contract('ypq,qoi->pioy', int2c_ip1, rhok0_P_I)
hk_ao_aux[i0:i1] -= contract("piox,pioy->ipxy", rhok1_Pko, wk1_P_I)
wk1_P_I = None

# (10|0)(0|1)(0|00)
for q0,q1 in lib.prange(0,naux,64):
wk1_I = contract('yqp,ipox->iqoxy', int2c_ip1_inv[:,q0:q1], wk1_Pko_islice)
hk_ao_aux[i0:i1,q0:q1] -= contract('qoi,iqoxy->iqxy', rhok0_P_I[q0:q1], wk1_I)
#for q0,q1 in lib.prange(0,naux,64):
wk1_I = contract('yqp,piox->qioxy', int2c_ip1, rhok1_Pko)
hk_ao_aux[i0:i1] -= contract('qoi,qioxy->iqxy', rhok0_P_I, wk1_I)
#wk1_I = contract('piox,qoi->ipqx', rhok1_Pko, rhok0_P_I)
#hk_ao_aux[i0:i1] -= contract('ipqx,yqp->iqxy', wk1_I, int2c_ip1)
wk1_I = rhok0_P_I = None
wk1_Pko_islice = None

wk1_Pko = None
t1 = log.timer_debug1('intermediate variables with int3c2e_ip1', *t1)

Expand Down Expand Up @@ -249,6 +261,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
# aux-aux pair
if hessobj.auxbasis_response > 1:
wj0_10 = contract('ypq,p->ypq', int2c_ip1, rhoj0_P)
int2c_ip1_inv = contract('yqp,pr->yqr', int2c_ip1, int2c_inv)

rhoj0_10 = contract('p,xpq->xpq', rhoj0_P, int2c_ip1_inv) # (1|0)(0|00)
hj_aux_aux += .5 * contract('xpr,yqr->pqxy', rhoj0_10, wj0_10) # (00|0)(1|0), (0|1)(0|00)
Expand Down Expand Up @@ -448,21 +461,22 @@ def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None,
dm0_tag = tag_array(dm0, occ_coeff=mocc)

int2c = take_last2d(int2c, aux_ao_idx)
int2c_inv = pinv(int2c, lindep=LINEAR_DEP_THRESHOLD)
solve_j2c = _gen_metric_solver(int2c)
int2c = None

wj, wk_Pl_ = int3c2e.get_int3c2e_wjk(mol, auxmol, dm0_tag, omega=omega)
rhoj0 = contract('pq,q->p', int2c_inv, wj)
rhoj0 = solve_j2c(wj)

wj = None
if isinstance(wk_Pl_, cupy.ndarray):
rhok0_Pl_ = contract('pq,qio->pio', int2c_inv, wk_Pl_)
rhok0_Pl_ = solve_j2c(wk_Pl_)
else:
rhok0_Pl_ = np.empty_like(wk_Pl_)
for p0, p1 in lib.prange(0,nao,64):
wk_tmp = cupy.asarray(wk_Pl_[:,p0:p1])
rhok0_Pl_[:,p0:p1] = contract('pq,qio->pio', int2c_inv, wk_tmp).get()
rhok0_Pl_[:,p0:p1] = solve_j2c(wk_tmp).get()
wk_tmp = None
wk_Pl_ = int2c_inv = None
wk_Pl_ = None

# -----------------------------
# int3c_ip1 contributions
Expand Down
Loading

0 comments on commit 761f85f

Please sign in to comment.