From ed64bbce9cc0117d52be3c4efda2c7ff7809ba81 Mon Sep 17 00:00:00 2001 From: Xiaojie Wu Date: Wed, 3 Jan 2024 17:43:12 -0800 Subject: [PATCH 01/10] Bugfix transpose sum (#73) * fixed a bug in screen_index * added unit test for to_gpu * new grids group scheme * use grid_aligned in gpu4pyscf.__config__ * fixed a bug in eval_ao * fixed a bug in transpose_sum * remove print --- gpu4pyscf/__config__.py | 6 +-- gpu4pyscf/__init__.py | 2 +- gpu4pyscf/df/df.py | 5 ++- gpu4pyscf/df/df_jk.py | 5 +-- gpu4pyscf/dft/gen_grid.py | 21 +--------- gpu4pyscf/dft/tests/test_ao_values.py | 1 - gpu4pyscf/dft/tests/test_grids.py | 2 +- gpu4pyscf/lib/cupy_helper.py | 4 +- gpu4pyscf/lib/cupy_helper/transpose.cu | 53 ++++++++++++++----------- gpu4pyscf/lib/gdft/gen_grids.cu | 1 - gpu4pyscf/lib/gdft/nr_eval_gto.cu | 52 ++++++++++++------------ gpu4pyscf/lib/gdft/vv10.cu | 20 +++------- gpu4pyscf/lib/tests/test_cupy_helper.py | 4 +- 13 files changed, 76 insertions(+), 100 deletions(-) diff --git a/gpu4pyscf/__config__.py b/gpu4pyscf/__config__.py index 987ca87c..5b740207 100644 --- a/gpu4pyscf/__config__.py +++ b/gpu4pyscf/__config__.py @@ -13,7 +13,7 @@ # such as V100-32G elif props['totalGlobalMem'] >= 32 * GB: min_ao_blksize = 128 - min_grid_blksize = 256*256#128*128 + min_grid_blksize = 128*128 ao_aligned = 32 grid_aligned = 128 mem_fraction = 0.9 @@ -21,7 +21,7 @@ # such as A30-24GB elif props['totalGlobalMem'] >= 16 * GB: min_ao_blksize = 128 - min_grid_blksize = 256*256#128*128 + min_grid_blksize = 128*128 ao_aligned = 32 grid_aligned = 128 mem_fraction = 0.9 @@ -35,4 +35,4 @@ mem_fraction = 0.9 number_of_threads = 1024 * 80 -cupy.get_default_memory_pool().set_limit(fraction=mem_fraction) \ No newline at end of file +cupy.get_default_memory_pool().set_limit(fraction=mem_fraction) diff --git a/gpu4pyscf/__init__.py b/gpu4pyscf/__init__.py index 004dfb05..5fd45d2f 100644 --- a/gpu4pyscf/__init__.py +++ b/gpu4pyscf/__init__.py @@ -1,5 +1,5 @@ from . import lib, grad, hessian, solvent, scf, dft -__version__ = '0.6.13' +__version__ = '0.6.14' # monkey patch libxc reference due to a bug in nvcc from pyscf.dft import libxc diff --git a/gpu4pyscf/df/df.py b/gpu4pyscf/df/df.py index 0b40d26e..ff3c6877 100644 --- a/gpu4pyscf/df/df.py +++ b/gpu4pyscf/df/df.py @@ -22,7 +22,7 @@ from pyscf import lib from pyscf.df import df, addons from gpu4pyscf.lib.cupy_helper import ( - cholesky, tag_array, get_avail_mem, cart2sph, take_last2d) + cholesky, tag_array, get_avail_mem, cart2sph, take_last2d, transpose_sum) from gpu4pyscf.df import int3c2e, df_jk from gpu4pyscf.lib import logger from gpu4pyscf import __config__ @@ -262,7 +262,8 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False): row = intopt.ao_pairs_row[cp_ij_id] - i0 col = intopt.ao_pairs_col[cp_ij_id] - j0 if cpi == cpj: - ints_slices = ints_slices + ints_slices.transpose([0,2,1]) + #ints_slices = ints_slices + ints_slices.transpose([0,2,1]) + transpose_sum(ints_slices) ints_slices = ints_slices[:,col,row] if cd_low.tag == 'eig': diff --git a/gpu4pyscf/df/df_jk.py b/gpu4pyscf/df/df_jk.py index 23a46850..ff0cbd7e 100644 --- a/gpu4pyscf/df/df_jk.py +++ b/gpu4pyscf/df/df_jk.py @@ -290,15 +290,14 @@ def get_jk(dfobj, dms_tag, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e- vj_packed += cupy.dot(rhoj, cderi_sparse.T) if with_k: rhok = contract('Lij,jk->Lki', cderi, occ_coeff) - #vk[0] += contract('Lki,Lkj->ij', rhok, rhok) - cublas.syrk('T', rhok.reshape([-1,nao]), out=vk[0], alpha=1.0, beta=1.0, lower=True) + #vk[0] += 2.0 * contract('Lki,Lkj->ij', rhok, rhok) + cublas.syrk('T', rhok.reshape([-1,nao]), out=vk[0], alpha=2.0, beta=1.0, lower=True) if with_j: vj[:,rows,cols] = vj_packed vj[:,cols,rows] = vj_packed if with_k: vk[0][numpy.diag_indices(nao)] *= 0.5 transpose_sum(vk) - vk *= 2.0 # CP-HF K matrix elif hasattr(dms_tag, 'mo1'): if with_j: diff --git a/gpu4pyscf/dft/gen_grid.py b/gpu4pyscf/dft/gen_grid.py index 40c6947a..5ced5d94 100644 --- a/gpu4pyscf/dft/gen_grid.py +++ b/gpu4pyscf/dft/gen_grid.py @@ -186,27 +186,8 @@ def gen_grids_partition(atm_coords, coords, a): natm = atm_coords.shape[0] ngrids = coords.shape[0] assert ngrids < 65535 * 16 - #x_i = cupy.expand_dims(atm_coords, axis=1) - #x_g = cupy.expand_dims(coords, axis=0) - #squared_diff = (x_i - x_g)**2 - #dist_ig = cupy.sum(squared_diff, axis=2)**0.5 - #x_j = cupy.expand_dims(atm_coords, axis=0) - #squared_diff = (x_i - x_j)**2 - #dist_ij = cupy.sum(squared_diff, axis=2)**0.5 - - pbecke = cupy.ones([natm, ngrids], order='C') - ''' - err = libgdft.GDFTgen_grid_partition( - ctypes.cast(stream.ptr, ctypes.c_void_p), - ctypes.cast(pbecke.data.ptr, ctypes.c_void_p), - ctypes.cast(dist_ig.data.ptr, ctypes.c_void_p), - ctypes.cast(dist_ij.data.ptr, ctypes.c_void_p), - ctypes.cast(a.data.ptr, ctypes.c_void_p), - ctypes.c_int(ngrids), - ctypes.c_int(natm) - ) - ''' + pbecke = cupy.empty([natm, ngrids], order='C') atm_coords = cupy.asarray(atm_coords, order='F') err = libgdft.GDFTgen_grid_partition( ctypes.cast(stream.ptr, ctypes.c_void_p), diff --git a/gpu4pyscf/dft/tests/test_ao_values.py b/gpu4pyscf/dft/tests/test_ao_values.py index 78ae6c98..5ee875b6 100644 --- a/gpu4pyscf/dft/tests/test_ao_values.py +++ b/gpu4pyscf/dft/tests/test_ao_values.py @@ -73,7 +73,6 @@ def test_ao_sph_deriv2(self): ao_cpu = cupy.asarray(ao) ni = NumInt(xc='LDA') ao_gpu = numint.eval_ao(ni, mol_sph, coords, deriv=2) - #idx = cupy.argwhere(cupy.abs(ao_gpu - ao_cpu) > 1e-10) assert cupy.linalg.norm(ao_cpu - ao_gpu) < 1e-8 def test_ao_sph_deriv3(self): diff --git a/gpu4pyscf/dft/tests/test_grids.py b/gpu4pyscf/dft/tests/test_grids.py index 534ece58..d675393b 100644 --- a/gpu4pyscf/dft/tests/test_grids.py +++ b/gpu4pyscf/dft/tests/test_grids.py @@ -32,7 +32,7 @@ def setUpModule(): O 0.000000 0.000000 0.117790 H 0.000000 0.755453 -0.471161 H 0.000000 -0.755453 -0.471161''', - basis = 'ccpvdz', + basis = 'ccpvqz', charge = 1, spin = 1, # = 2S = spin_up - spin_down output = '/dev/null') diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py index f52f68f7..cebbbfd3 100644 --- a/gpu4pyscf/lib/cupy_helper.py +++ b/gpu4pyscf/lib/cupy_helper.py @@ -264,9 +264,9 @@ def take_last2d(a, indices, out=None): raise RuntimeError('failed in take_last2d kernel') return out -def transpose_sum(a): +def transpose_sum(a, stream=None): ''' - transpose (0,2,1) + return a + a.transpose(0,2,1) ''' assert a.flags.c_contiguous assert a.ndim == 3 diff --git a/gpu4pyscf/lib/cupy_helper/transpose.cu b/gpu4pyscf/lib/cupy_helper/transpose.cu index 7e927059..78b3dd39 100644 --- a/gpu4pyscf/lib/cupy_helper/transpose.cu +++ b/gpu4pyscf/lib/cupy_helper/transpose.cu @@ -35,36 +35,41 @@ static void _dsymm_triu(double *a, int n) a[off + j * N + i] = a[off + i * N + j]; } -__global__ +__global__ void _transpose_sum(double *a, int n) { + if(blockIdx.x > blockIdx.y){ + return; + } __shared__ double block[BLOCK_DIM][BLOCK_DIM+1]; - - // read the matrix tile into shared memory - // load one element per thread from device memory (idata) and store it - // in transposed order in block[][] - unsigned int xIndex = blockIdx.x * BLOCK_DIM + threadIdx.x; - unsigned int yIndex = blockIdx.y * BLOCK_DIM + threadIdx.y; - unsigned int zIndex = blockIdx.z; - unsigned int off = zIndex * n * n; - if((xIndex < n) && (yIndex < n)) - { - unsigned int index_in = yIndex * n + xIndex + off; - block[threadIdx.y][threadIdx.x] = a[index_in]; - } + unsigned int blockx_off = blockIdx.x * BLOCK_DIM; + unsigned int blocky_off = blockIdx.y * BLOCK_DIM; + unsigned int x0 = blockx_off + threadIdx.x; + unsigned int y0 = blocky_off + threadIdx.y; + unsigned int x1 = blocky_off + threadIdx.x; + unsigned int y1 = blockx_off + threadIdx.y; + unsigned int z = blockIdx.z; + + unsigned int off = n * n * z; + unsigned int xy0 = y0 * n + x0 + off; + unsigned int xy1 = y1 * n + x1 + off; - // synchronise to ensure all writes to block[][] have completed - __syncthreads(); + if (x0 < n && y0 < n){ + block[threadIdx.y][threadIdx.x] = a[xy0]; + } + __syncthreads(); + if (x1 < n && y1 < n){ + block[threadIdx.x][threadIdx.y] += a[xy1]; + } + __syncthreads(); - // write the transposed matrix tile to global memory (odata) in linear order - xIndex = blockIdx.y * BLOCK_DIM + threadIdx.x; - yIndex = blockIdx.x * BLOCK_DIM + threadIdx.y; - if((xIndex < n) && (yIndex < n)) - { - unsigned int index_out = yIndex * n + xIndex + off; - a[index_out] += block[threadIdx.x][threadIdx.y]; - } + if(x0 < n && y0 < n){ + a[xy0] = block[threadIdx.y][threadIdx.x]; + } + if(x1 < n && y1 < n){ + a[xy1] = block[threadIdx.x][threadIdx.y]; + } } extern "C" { diff --git a/gpu4pyscf/lib/gdft/gen_grids.cu b/gpu4pyscf/lib/gdft/gen_grids.cu index 2c484ae3..d71eedf9 100644 --- a/gpu4pyscf/lib/gdft/gen_grids.cu +++ b/gpu4pyscf/lib/gdft/gen_grids.cu @@ -40,7 +40,6 @@ int ngrids, int natm) __shared__ double zj[NATOM_PER_BLOCK]; __shared__ double a_smem[NATOM_PER_BLOCK]; __shared__ double dij_smem[NATOM_PER_BLOCK]; - const int tx = threadIdx.x; for (int atom_i = 0; atom_i < natm; atom_i++){ diff --git a/gpu4pyscf/lib/gdft/nr_eval_gto.cu b/gpu4pyscf/lib/gdft/nr_eval_gto.cu index 0295c0c0..48ff5729 100644 --- a/gpu4pyscf/lib/gdft/nr_eval_gto.cu +++ b/gpu4pyscf/lib/gdft/nr_eval_gto.cu @@ -28,7 +28,7 @@ #include "nr_eval_gto.cuh" #include "contract_rho.cuh" -#define THREADS 128 +#define NG_PER_BLOCK 128 #define LMAX 8 #define GTO_MAX_CART 15 @@ -1121,15 +1121,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets) g12 = ax * ry * ry * rz * rz; g13 = ax * ry * rz * rz * rz; g14 = ax * rz * rz * rz * rz; - gtox[ grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ; + gtox[ grid_id] = 2.503342941796704538 * (g1 - g6) ; gtox[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11; - gtox[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ; - gtox[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11; - gtox[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14; - gtox[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ; - gtox[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12; + gtox[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6); + gtox[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11); + gtox[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14; + gtox[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7); + gtox[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0); gtox[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ; - gtox[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10; + gtox[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3; double ay = ce_2a * ry; g0 = ay * rx * rx * rx * rx; @@ -1147,15 +1147,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets) g12 = (ay * ry + 2 * ce) * ry * rz * rz; g13 = (ay * ry + ce) * rz * rz * rz; g14 = ay * rz * rz * rz * rz; - gtoy[ grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ; + gtoy[ grid_id] = 2.503342941796704538 * (g1 - g6) ; gtoy[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11; - gtoy[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ; - gtoy[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11; - gtoy[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14; - gtoy[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ; - gtoy[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12; + gtoy[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6); + gtoy[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11); + gtoy[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14; + gtoy[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7); + gtoy[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0); gtoy[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ; - gtoy[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10; + gtoy[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3; double az = ce_2a * rz; g0 = az * rx * rx * rx * rx; @@ -1173,15 +1173,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets) g12 = (az * rz + 2 * ce) * ry * ry * rz; g13 = (az * rz + 3 * ce) * ry * rz * rz; g14 = (az * rz + 4 * ce) * rz * rz * rz; - gtoz[ grid_id] = 2.503342941796704538 * g1 - 2.503342941796704530 * g6 ; + gtoz[ grid_id] = 2.503342941796704538 * (g1 - g6) ; gtoz[1 *ngrids+grid_id] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11; - gtoz[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * g1 - 0.946174695757560014 * g6 ; - gtoz[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * g4 - 2.007139630671867500 * g11; - gtoz[4 *ngrids+grid_id] = 0.317356640745612911 * g0 + 0.634713281491225822 * g3 - 2.538853125964903290 * g5 + 0.317356640745612911 * g10 - 2.538853125964903290 * g12 + 0.846284375321634430 * g14; - gtoz[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * g2 - 2.007139630671867500 * g7 ; - gtoz[6 *ngrids+grid_id] = 2.838524087272680054 * g5 + 0.473087347878780009 * g10 - 0.473087347878780002 * g0 - 2.838524087272680050 * g12; + gtoz[2 *ngrids+grid_id] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6); + gtoz[3 *ngrids+grid_id] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11); + gtoz[4 *ngrids+grid_id] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14; + gtoz[5 *ngrids+grid_id] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7); + gtoz[6 *ngrids+grid_id] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0); gtoz[7 *ngrids+grid_id] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ; - gtoz[8 *ngrids+grid_id] = 0.625835735449176134 * g0 - 3.755014412695056800 * g3 + 0.625835735449176134 * g10; + gtoz[8 *ngrids+grid_id] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3; } } @@ -1559,8 +1559,8 @@ int GDFTeval_gto(cudaStream_t stream, double *ao, int deriv, int cart, offsets.bas_indices = bas_indices; offsets.nbas = local_ctr_offsets[nctr]; offsets.nao = nao; - dim3 threads(THREADS); - dim3 blocks((ngrids+THREADS-1)/THREADS); + dim3 threads(NG_PER_BLOCK); + dim3 blocks((ngrids+NG_PER_BLOCK-1)/NG_PER_BLOCK); for (int ictr = 0; ictr < nctr; ++ictr) { int local_ish = local_ctr_offsets[ictr]; @@ -1706,8 +1706,8 @@ int GDFTeval_gto(cudaStream_t stream, double *ao, int deriv, int cart, int GDFTscreen_index(cudaStream_t stream, int *non0shl_idx, double cutoff, double *grids, int ngrids, int *bas_loc, int nbas, int *bas) { - dim3 threads(THREADS); - dim3 blocks((ngrids+THREADS-1)/THREADS); + dim3 threads(NG_PER_BLOCK); + dim3 blocks((ngrids+NG_PER_BLOCK-1)/NG_PER_BLOCK); for (int shl_id = 0; shl_id < nbas; ++shl_id) { int l = bas[ANG_OF+shl_id*BAS_SLOTS]; diff --git a/gpu4pyscf/lib/gdft/vv10.cu b/gpu4pyscf/lib/gdft/vv10.cu index b7be564d..bdcb01da 100644 --- a/gpu4pyscf/lib/gdft/vv10.cu +++ b/gpu4pyscf/lib/gdft/vv10.cu @@ -168,28 +168,20 @@ static void vv10_grad_kernel(double *Fvec, const double *vvcoords, const double __syncthreads(); for (int l = 0, M = min(NG_PER_BLOCK, vvngrids - j); l < M; ++l){ double3 xj_tmp = xj_t[l]; - double pjx = xj_tmp.x; - double pjy = xj_tmp.y; - double pjz = xj_tmp.z; - // about 23 operations for each pair - double DX = pjx - xi; - double DY = pjy - yi; - double DZ = pjz - zi; + double DX = xj_tmp.x - xi; + double DY = xj_tmp.y - yi; + double DZ = xj_tmp.z - zi; double R2 = DX*DX + DY*DY + DZ*DZ; double3 kp_tmp = kp_t[l]; - double Kpj = kp_tmp.x; - double W0pj = kp_tmp.y; - double RpWj = kp_tmp.z; - - double gp = R2*W0pj + Kpj; + double gp = R2*kp_tmp.y + kp_tmp.x; double g = R2*W0i + Ki; double gt = g + gp; double ggp = g * gp; double ggt_gp = gt * ggp; - double T = RpWj / (ggt_gp * ggt_gp); - double Q = T * ((W0i*gp + W0pj*g)*gt + (W0i+W0pj)*ggp); + double T = kp_tmp.z / (ggt_gp * ggt_gp); + double Q = T * ((W0i*gp + kp_tmp.y*g)*gt + (W0i+kp_tmp.y)*ggp); FX += Q * DX; FY += Q * DY; diff --git a/gpu4pyscf/lib/tests/test_cupy_helper.py b/gpu4pyscf/lib/tests/test_cupy_helper.py index 513f7c38..614e0fa3 100644 --- a/gpu4pyscf/lib/tests/test_cupy_helper.py +++ b/gpu4pyscf/lib/tests/test_cupy_helper.py @@ -31,8 +31,8 @@ def test_take_last2d(self): assert(cupy.linalg.norm(a[:,indices][:,:,indices] - b) < 1e-10) def test_transpose_sum(self): - n = 3 - count = 4 + n = 1287 + count = 127 a = cupy.random.rand(count,n,n) b = a + a.transpose(0,2,1) transpose_sum(a) From e19d51e6f1f6f5e22023dc5dac48ca8ecd388822 Mon Sep 17 00:00:00 2001 From: Xiaojie Wu Date: Fri, 5 Jan 2024 15:59:02 -0800 Subject: [PATCH 02/10] Get vxc (#74) * fixed a bug in screen_index * added unit test for to_gpu * new grids group scheme * use grid_aligned in gpu4pyscf.__config__ * fixed a bug in eval_ao * fixed a bug in transpose_sum * remove print * new get_vxc scheme --- examples/dft_driver.py | 1 - gpu4pyscf/dft/numint.py | 80 ++++++--- gpu4pyscf/lib/cupy_helper.py | 4 + gpu4pyscf/lib/cupy_helper/add_sparse.cu | 4 +- gpu4pyscf/lib/cupy_helper/block_diag.cu | 2 +- gpu4pyscf/lib/cupy_helper/take_last2d.cu | 2 +- gpu4pyscf/lib/cupy_helper/transpose.cu | 4 +- gpu4pyscf/lib/gdft/contract_rho.cu | 1 - gpu4pyscf/lib/gdft/nr_eval_gto.cu | 199 +++++++++++++++++------ 9 files changed, 216 insertions(+), 81 deletions(-) diff --git a/examples/dft_driver.py b/examples/dft_driver.py index 2b2c8cd6..80830e6c 100644 --- a/examples/dft_driver.py +++ b/examples/dft_driver.py @@ -34,7 +34,6 @@ basis=bas, max_memory=32000) # set verbose >= 6 for debugging timer - mol.verbose = 4 mf_df = rks.RKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis) diff --git a/gpu4pyscf/dft/numint.py b/gpu4pyscf/dft/numint.py index f7758f1c..0e46c611 100644 --- a/gpu4pyscf/dft/numint.py +++ b/gpu4pyscf/dft/numint.py @@ -24,7 +24,8 @@ from pyscf.dft import numint from pyscf.gto.eval_gto import NBINS, CUTOFF, make_screen_index from gpu4pyscf.scf.hf import basis_seg_contraction -from gpu4pyscf.lib.cupy_helper import contract, get_avail_mem, load_library, add_sparse, release_gpu_stack, take_last2d +from gpu4pyscf.lib.cupy_helper import ( + contract, get_avail_mem, load_library, add_sparse, release_gpu_stack, take_last2d, transpose_sum) from gpu4pyscf.dft import xc_deriv, xc_alias, libxc from gpu4pyscf import __config__ from gpu4pyscf.lib import logger @@ -83,11 +84,17 @@ def eval_ao(ni, mol, coords, deriv=0, shls_slice=None, nao_slice=None, ao_loc_sl comp = (deriv+1)*(deriv+2)*(deriv+3)//6 stream = cupy.cuda.get_current_stream() + # ao must be set to zero due to implementation + if deriv > 1: + ao = cupy.zeros((comp, nao_slice, ngrids), order='C') + else: + ao = cupy.empty((comp, nao_slice, ngrids), order='C') + + #ao = cupy.zeros((comp, nao_slice, ngrids), order='C') if not with_opt: # mol may be different to _GDFTOpt.mol. # nao should be consistent with the _GDFTOpt.mol object coeff = cupy.asarray(opt.coeff) - ao = cupy.zeros((comp, nao_slice, ngrids), order='C') with opt.gdft_envs_cache(): err = libgdft.GDFTeval_gto( ctypes.cast(stream.ptr, ctypes.c_void_p), @@ -102,7 +109,6 @@ def eval_ao(ni, mol, coords, deriv=0, shls_slice=None, nao_slice=None, ao_loc_sl mol._bas.ctypes.data_as(ctypes.c_void_p)) ao = contract('nig,ij->njg', ao, coeff).transpose([0,2,1]) else: - ao = cupy.zeros((comp, nao_slice, ngrids), order='C') err = libgdft.GDFTeval_gto( ctypes.cast(stream.ptr, ctypes.c_void_p), ctypes.cast(ao.data.ptr, ctypes.c_void_p), @@ -174,7 +180,7 @@ def eval_rho1(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA', raise NotImplementedError def eval_rho2(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA', - with_lapl=True, verbose=None): + with_lapl=True, verbose=None, out=None): xctype = xctype.upper() if xctype == 'LDA' or xctype == 'HF': _, ngrids = ao.shape @@ -467,23 +473,48 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, ao_deriv = 0 else: ao_deriv = 1 + ngrids = grids.weights.size + if xctype == 'LDA': + rho_tot = cupy.empty([nset,1,ngrids]) + elif xctype == 'GGA': + rho_tot = cupy.empty([nset,4,ngrids]) + else: + rho_tot = cupy.empty([nset,6,ngrids]) + p0 = p1 = 0 for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv): + p1 = p0 + weight.size for i in range(nset): t0 = log.init_timer() if mo_coeff is None: - rho = eval_rho(mol, ao_mask, dms[i][np.ix_(idx,idx)], xctype=xctype, hermi=1) + rho_tot[i,:,p0:p1] = eval_rho(mol, ao_mask, dms[i][np.ix_(idx,idx)], xctype=xctype, hermi=1) else: mo_coeff_mask = mo_coeff[idx,:] - rho = eval_rho2(mol, ao_mask, mo_coeff_mask, mo_occ, None, xctype) - + rho_tot[i,:,p0:p1] = eval_rho2(mol, ao_mask, mo_coeff_mask, mo_occ, None, xctype) t1 = log.timer_debug1('eval rho', *t0) - exc, vxc = ni.eval_xc_eff(xc_code, rho, deriv=1, xctype=xctype)[:2] - vxc = cupy.asarray(vxc, order='C') - exc = cupy.asarray(exc, order='C') - t1 = log.timer_debug1('eval vxc', *t1) + p0 = p1 + + vxc_tot = [] + for i in range(nset): + if xctype == 'LDA': + exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i][0], deriv=1, xctype=xctype)[:2] + else: + exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i], deriv=1, xctype=xctype)[:2] + vxc = cupy.asarray(vxc, order='C') + exc = cupy.asarray(exc, order='C') + den = rho_tot[i][0] * grids.weights + nelec[i] = den.sum() + excsum[i] = cupy.sum(den * exc[:,0]) + vxc_tot.append(vxc) + t1 = log.timer_debug1('eval vxc', *t1) + + p0 = p1 = 0 + for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv): + p1 = p0 + weight.size + for i in range(nset): + vxc = vxc_tot[i][:,p0:p1] if xctype == 'LDA': - den = rho * weight + #den = rho * weight wv = weight * vxc[0] ''' if USE_SPARSITY == 0: @@ -499,7 +530,7 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, else: raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented') elif xctype == 'GGA': - den = rho[0] * weight + #den = rho[0] * weight wv = vxc * weight wv[0] *= .5 ''' @@ -512,14 +543,13 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, ''' if USE_SPARSITY == 2: aow = _scale_ao(ao_mask, wv) - #vmat[i][cupy.ix_(mask, mask)] += ao_mask[0].dot(aow.T) add_sparse(vmat[i], ao_mask[0].dot(aow.T), idx) else: raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented') elif xctype == 'NLC': raise NotImplementedError('NLC') elif xctype == 'MGGA': - den = rho[0] * weight + #den = rho[0] * weight wv = vxc * weight wv[[0, 4]] *= .5 # *.5 for v+v.T ''' @@ -545,9 +575,8 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, pass else: raise NotImplementedError(f'numint.nr_rks for functional {xc_code}') - nelec[i] += den.sum() - excsum[i] += cupy.dot(den, exc)[0] t1 = log.timer_debug1('integration', *t1) + p0 = p1 vmat = contract('pi,npq->niq', coeff, vmat) vmat = contract('qj,niq->nij', coeff, vmat) @@ -555,8 +584,8 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, #vmat = take_last2d(vmat, rev_ao_idx) if xctype != 'LDA': - #transpose_sum(vmat) vmat = vmat + vmat.transpose([0,2,1]) + #transpose_sum(vmat) if FREE_CUPY_CACHE: dms = None @@ -567,7 +596,7 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, excsum = excsum[0] vmat = vmat[0] - return nelec, excsum, vmat#np.asarray(vmat) + return nelec, excsum, vmat def nr_uks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, max_memory=2000, verbose=None): @@ -1286,13 +1315,12 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000, zero_idx = cupy.asarray(zero_idx, dtype=np.int32) pad = (len(idx) + AO_ALIGNMENT - 1) // AO_ALIGNMENT * AO_ALIGNMENT - len(idx) idx = cupy.hstack([idx, zero_idx[:pad]]) + pad = min(pad, len(zero_idx)) non0shl_idx = cupy.asarray(np.where(non0shl_idx)[0], dtype=np.int32) - - ni.non0ao_idx[deriv, block_id, blksize, ngrids] = (idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice) + ni.non0ao_idx[deriv, block_id, blksize, ngrids] = (pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice) log.timer_debug1('init ao sparsity', *t0) else: - idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice = ni.non0ao_idx[deriv, block_id, blksize, ngrids] - + pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice = ni.non0ao_idx[deriv, block_id, blksize, ngrids] t0 = log.init_timer() ao_mask = eval_ao( ni, mol, coords, deriv, @@ -1300,7 +1328,11 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000, shls_slice=non0shl_idx, ao_loc_slice=ao_loc_slice, ctr_offsets_slice=ctr_offsets_slice) - + if pad > 0: + if deriv == 0: + ao_mask[-pad:,:] = 0.0 + else: + ao_mask[:,-pad:,:] = 0.0 block_id += 1 log.timer_debug1('evaluate ao slice', *t0) yield ao_mask, idx, weight, coords diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py index cebbbfd3..3d0d25f5 100644 --- a/gpu4pyscf/lib/cupy_helper.py +++ b/gpu4pyscf/lib/cupy_helper.py @@ -154,7 +154,9 @@ def add_sparse(a, b, indices): count = 1 else: raise RuntimeError('add_sparse only supports 2d or 3d tensor') + stream = cupy.cuda.get_current_stream() err = libcupy_helper.add_sparse( + ctypes.cast(stream.ptr, ctypes.c_void_p), ctypes.cast(a.data.ptr, ctypes.c_void_p), ctypes.cast(b.data.ptr, ctypes.c_void_p), ctypes.cast(indices.data.ptr, ctypes.c_void_p), @@ -272,7 +274,9 @@ def transpose_sum(a, stream=None): assert a.ndim == 3 n = a.shape[-1] count = a.shape[0] + stream = cupy.cuda.get_current_stream() err = libcupy_helper.transpose_sum( + ctypes.cast(stream.ptr, ctypes.c_void_p), ctypes.cast(a.data.ptr, ctypes.c_void_p), ctypes.c_int(n), ctypes.c_int(count) diff --git a/gpu4pyscf/lib/cupy_helper/add_sparse.cu b/gpu4pyscf/lib/cupy_helper/add_sparse.cu index d8033015..edccf7e1 100644 --- a/gpu4pyscf/lib/cupy_helper/add_sparse.cu +++ b/gpu4pyscf/lib/cupy_helper/add_sparse.cu @@ -39,11 +39,11 @@ void _add_sparse(double *a, double *b, int *indices, int n, int m, int count) extern "C" { __host__ -int add_sparse(double *a, double *b, int *indices, int n, int m, int count){ +int add_sparse(cudaStream_t stream, double *a, double *b, int *indices, int n, int m, int count){ int ntile = (m + THREADS - 1) / THREADS; dim3 threads(THREADS, THREADS); dim3 blocks(ntile, ntile); - _add_sparse<<>>(a, b, indices, n, m, count); + _add_sparse<<>>(a, b, indices, n, m, count); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { return 1; diff --git a/gpu4pyscf/lib/cupy_helper/block_diag.cu b/gpu4pyscf/lib/cupy_helper/block_diag.cu index f2176c67..145fe45c 100644 --- a/gpu4pyscf/lib/cupy_helper/block_diag.cu +++ b/gpu4pyscf/lib/cupy_helper/block_diag.cu @@ -24,7 +24,7 @@ static void _block_diag(double *out, int m, int n, double *diags, int ndiags, in int i = threadIdx.x; int j = threadIdx.y; int r = blockIdx.x; - + if (r >= ndiags){ return; } diff --git a/gpu4pyscf/lib/cupy_helper/take_last2d.cu b/gpu4pyscf/lib/cupy_helper/take_last2d.cu index 36342013..26a1a6a6 100644 --- a/gpu4pyscf/lib/cupy_helper/take_last2d.cu +++ b/gpu4pyscf/lib/cupy_helper/take_last2d.cu @@ -27,7 +27,7 @@ static void _take(double *a, const double *b, int *indices, int n) if (j >= n || k >= n) { return; } - + int j_b = indices[j]; int k_b = indices[k]; int off = i * n * n; diff --git a/gpu4pyscf/lib/cupy_helper/transpose.cu b/gpu4pyscf/lib/cupy_helper/transpose.cu index 78b3dd39..748c83a8 100644 --- a/gpu4pyscf/lib/cupy_helper/transpose.cu +++ b/gpu4pyscf/lib/cupy_helper/transpose.cu @@ -88,11 +88,11 @@ int CPdsymm_triu(double *a, int n, int counts) } __host__ -int transpose_sum(double *a, int n, int counts){ +int transpose_sum(cudaStream_t stream, double *a, int n, int counts){ int ntile = (n + THREADS - 1) / THREADS; dim3 threads(THREADS, THREADS); dim3 blocks(ntile, ntile, counts); - _transpose_sum<<>>(a, n); + _transpose_sum<<>>(a, n); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { return 1; diff --git a/gpu4pyscf/lib/gdft/contract_rho.cu b/gpu4pyscf/lib/gdft/contract_rho.cu index 1957928e..5c6dbd1c 100644 --- a/gpu4pyscf/lib/gdft/contract_rho.cu +++ b/gpu4pyscf/lib/gdft/contract_rho.cu @@ -30,7 +30,6 @@ void GDFTcontract_rho_kernel(double *rho, double *bra, double *ket, int ngrids, { int grid_id = blockIdx.x * blockDim.x + threadIdx.x; const bool active = grid_id < ngrids; - size_t Ngrids = ngrids; double v = 0; if (active){ diff --git a/gpu4pyscf/lib/gdft/nr_eval_gto.cu b/gpu4pyscf/lib/gdft/nr_eval_gto.cu index 48ff5729..1e2689d7 100644 --- a/gpu4pyscf/lib/gdft/nr_eval_gto.cu +++ b/gpu4pyscf/lib/gdft/nr_eval_gto.cu @@ -86,28 +86,28 @@ void _screen_index(int *non0shl_idx, double cutoff, int l, int ish, int nprim, d template __device__ static void _cart2sph(double g_cart[GTO_MAX_CART], double *g_sph, int stride, int grid_id){ if (ANG == 0) { - g_sph[grid_id + 0*stride] += g_cart[0]; + g_sph[grid_id ] += g_cart[0]; } else if (ANG == 1){ - g_sph[grid_id + 0*stride] += g_cart[0]; - g_sph[grid_id + 1*stride] += g_cart[1]; - g_sph[2*stride] += g_cart[2]; + g_sph[grid_id ] += g_cart[0]; + g_sph[grid_id + stride] += g_cart[1]; + g_sph[grid_id + 2*stride] += g_cart[2]; } else if (ANG == 2){ - g_sph[grid_id + 0*stride] += 1.092548430592079070 * g_cart[1]; - g_sph[grid_id + 1*stride] += 1.092548430592079070 * g_cart[4]; + g_sph[grid_id ] += 1.092548430592079070 * g_cart[1]; + g_sph[grid_id + stride] += 1.092548430592079070 * g_cart[4]; g_sph[grid_id + 2*stride] += 0.630783130505040012 * g_cart[5] - 0.315391565252520002 * (g_cart[0] + g_cart[3]); g_sph[grid_id + 3*stride] += 1.092548430592079070 * g_cart[2]; g_sph[grid_id + 4*stride] += 0.546274215296039535 * (g_cart[0] - g_cart[3]); } else if (ANG == 3){ - g_sph[grid_id + 0*stride] += 1.770130769779930531 * g_cart[1] - 0.590043589926643510 * g_cart[6]; - g_sph[grid_id + 1*stride] += 2.890611442640554055 * g_cart[4]; + g_sph[grid_id ] += 1.770130769779930531 * g_cart[1] - 0.590043589926643510 * g_cart[6]; + g_sph[grid_id + stride] += 2.890611442640554055 * g_cart[4]; g_sph[grid_id + 2*stride] += 1.828183197857862944 * g_cart[8] - 0.457045799464465739 * (g_cart[1] + g_cart[6]); g_sph[grid_id + 3*stride] += 0.746352665180230782 * g_cart[9] - 1.119528997770346170 * (g_cart[2] + g_cart[7]); g_sph[grid_id + 4*stride] += 1.828183197857862944 * g_cart[5] - 0.457045799464465739 * (g_cart[0] + g_cart[3]); g_sph[grid_id + 5*stride] += 1.445305721320277020 * (g_cart[2] - g_cart[7]); g_sph[grid_id + 6*stride] += 0.590043589926643510 * g_cart[0] - 1.770130769779930530 * g_cart[3]; } else if (ANG == 4){ - g_sph[grid_id + 0*stride] += 2.503342941796704538 * (g_cart[1] - g_cart[6]) ; - g_sph[grid_id + 1*stride] += 5.310392309339791593 * g_cart[4] - 1.770130769779930530 * g_cart[11]; + g_sph[grid_id ] += 2.503342941796704538 * (g_cart[1] - g_cart[6]) ; + g_sph[grid_id + stride] += 5.310392309339791593 * g_cart[4] - 1.770130769779930530 * g_cart[11]; g_sph[grid_id + 2*stride] += 5.677048174545360108 * g_cart[8] - 0.946174695757560014 * (g_cart[1] + g_cart[6]); g_sph[grid_id + 3*stride] += 2.676186174229156671 * g_cart[13]- 2.007139630671867500 * (g_cart[4] + g_cart[11]); g_sph[grid_id + 4*stride] += 0.317356640745612911 * (g_cart[0] + g_cart[10]) + 0.634713281491225822 * g_cart[3] - 2.538853125964903290 * (g_cart[5] + g_cart[12]) + 0.846284375321634430 * g_cart[14]; @@ -118,6 +118,92 @@ static void _cart2sph(double g_cart[GTO_MAX_CART], double *g_sph, int stride, in } } +template __device__ +static void _memset_cart(double *g_cart, int stride, int grid_id){ + if (ANG == 0){ + g_cart[grid_id] = 0.0; + } else if (ANG == 1){ + g_cart[grid_id ] = 0.0; + g_cart[grid_id + stride] = 0.0; + g_cart[grid_id + 2*stride] = 0.0; + } else if (ANG == 2){ + g_cart[grid_id ] = 0.0; + g_cart[grid_id + stride] = 0.0; + g_cart[grid_id + 2*stride] = 0.0; + g_cart[grid_id + 3*stride] = 0.0; + g_cart[grid_id + 4*stride] = 0.0; + g_cart[grid_id + 5*stride] = 0.0; + } else if (ANG == 3){ + g_cart[grid_id ] = 0.0; + g_cart[grid_id + stride] = 0.0; + g_cart[grid_id + 2*stride] = 0.0; + g_cart[grid_id + 3*stride] = 0.0; + g_cart[grid_id + 4*stride] = 0.0; + g_cart[grid_id + 5*stride] = 0.0; + g_cart[grid_id + 6*stride] = 0.0; + g_cart[grid_id + 7*stride] = 0.0; + g_cart[grid_id + 8*stride] = 0.0; + g_cart[grid_id + 9*stride] = 0.0; + } else if (ANG == 4){ + g_cart[grid_id ] = 0.0; + g_cart[grid_id + stride] = 0.0; + g_cart[grid_id + 2*stride] = 0.0; + g_cart[grid_id + 3*stride] = 0.0; + g_cart[grid_id + 4*stride] = 0.0; + g_cart[grid_id + 5*stride] = 0.0; + g_cart[grid_id + 6*stride] = 0.0; + g_cart[grid_id + 7*stride] = 0.0; + g_cart[grid_id + 8*stride] = 0.0; + g_cart[grid_id + 9*stride] = 0.0; + g_cart[grid_id +10*stride] = 0.0; + g_cart[grid_id +11*stride] = 0.0; + g_cart[grid_id +12*stride] = 0.0; + g_cart[grid_id +14*stride] = 0.0; + } else { + int i = 0; + for (int lx = ANG; lx >= 0; lx--){ + for (int ly = ANG - lx; ly >= 0; ly--, i++){ + g_cart[grid_id + i*stride] = 0.0; + } + } + } +} + +template __device__ +static void _memset_sph(double *g_sph, int stride, int grid_id){ + if (ANG == 0){ + g_sph[grid_id] = 0.0; + } else if (ANG == 1){ + g_sph[grid_id ] = 0.0; + g_sph[grid_id + stride] = 0.0; + g_sph[grid_id + 2*stride] = 0.0; + } else if (ANG == 2){ + g_sph[grid_id ] = 0.0; + g_sph[grid_id + stride] = 0.0; + g_sph[grid_id + 2*stride] = 0.0; + g_sph[grid_id + 3*stride] = 0.0; + g_sph[grid_id + 4*stride] = 0.0; + } else if (ANG == 3){ + g_sph[grid_id ] = 0.0; + g_sph[grid_id + stride] = 0.0; + g_sph[grid_id + 2*stride] = 0.0; + g_sph[grid_id + 3*stride] = 0.0; + g_sph[grid_id + 4*stride] = 0.0; + g_sph[grid_id + 5*stride] = 0.0; + g_sph[grid_id + 6*stride] = 0.0; + } else if (ANG == 4){ + g_sph[grid_id ] = 0.0; + g_sph[grid_id + stride] = 0.0; + g_sph[grid_id + 2*stride] = 0.0; + g_sph[grid_id + 3*stride] = 0.0; + g_sph[grid_id + 4*stride] = 0.0; + g_sph[grid_id + 5*stride] = 0.0; + g_sph[grid_id + 6*stride] = 0.0; + g_sph[grid_id + 7*stride] = 0.0; + g_sph[grid_id + 8*stride] = 0.0; + } +} + template __device__ static void _cart_gto(double *g, double ce, double *fx, double *fy, double *fz){ for (int lx = ANG, i = 0; lx >= 0; lx--){ @@ -962,46 +1048,52 @@ static void _sph_kernel_deriv1(BasOffsets offsets) */ gto[ grid_id] = 1.092548430592079070 * g1; gto[1*ngrids+grid_id] = 1.092548430592079070 * g4; - gto[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3); + gto[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3); gto[3*ngrids+grid_id] = 1.092548430592079070 * g2; gto[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3); double ax = ce_2a * rx; - g0 = (ax * rx + 2 * ce) * rx; - g1 = (ax * rx + ce) * ry; - g2 = (ax * rx + ce) * rz; + double ax_ce = ax * rx + ce; + double ax_2ce = ax_ce + ce; + g0 = ax_2ce * rx; + g1 = ax_ce * ry; + g2 = ax_ce * rz; g3 = ax * ry * ry; g4 = ax * ry * rz; g5 = ax * rz * rz; gtox[ grid_id] = 1.092548430592079070 * g1; gtox[1*ngrids+grid_id] = 1.092548430592079070 * g4; - gtox[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3); + gtox[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3); gtox[3*ngrids+grid_id] = 1.092548430592079070 * g2; gtox[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3); double ay = ce_2a * ry; + double ay_ce = ay * ry + ce; + double ay_2ce = ay_ce + ce; g0 = ay * rx * rx; - g1 = (ay * ry + ce) * rx; + g1 = ay_ce * rx; g2 = ay * rx * rz; - g3 = (ay * ry + 2 * ce) * ry; - g4 = (ay * ry + ce) * rz; + g3 = ay_2ce * ry; + g4 = ay_ce * rz; g5 = ay * rz * rz; gtoy[ grid_id] = 1.092548430592079070 * g1; gtoy[1*ngrids+grid_id] = 1.092548430592079070 * g4; - gtoy[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3); + gtoy[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3); gtoy[3*ngrids+grid_id] = 1.092548430592079070 * g2; gtoy[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3); double az = ce_2a * rz; + double az_ce = az * rz + ce; + double az_2ce = az_ce + ce; g0 = az * rx * rx; g1 = az * rx * ry; - g2 = (az * rz + ce) * rx; + g2 = az_ce * rx; g3 = az * ry * ry; - g4 = (az * rz + ce) * ry; - g5 = (az * rz + 2 * ce) * rz; + g4 = az_ce * ry; + g5 = az_2ce * rz; gtoz[ grid_id] = 1.092548430592079070 * g1; gtoz[1*ngrids+grid_id] = 1.092548430592079070 * g4; - gtoz[2*ngrids+grid_id] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3); + gtoz[2*ngrids+grid_id] = 0.315391565252520002 * (2 * g5 - g0 - g3); gtoz[3*ngrids+grid_id] = 1.092548430592079070 * g2; gtoz[4*ngrids+grid_id] = 0.546274215296039535 * (g0 - g3); } else if (ANG == 3) { @@ -1024,12 +1116,15 @@ static void _sph_kernel_deriv1(BasOffsets offsets) gto[6*ngrids+grid_id] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3; double ax = ce_2a * rx; - g0 = (ax * rx + 3 * ce) * rx * rx; - g1 = (ax * rx + 2 * ce) * rx * ry; - g2 = (ax * rx + 2 * ce) * rx * rz; - g3 = (ax * rx + ce) * ry * ry; - g4 = (ax * rx + ce) * ry * rz; - g5 = (ax * rx + ce) * rz * rz; + double ax_ce = ax * rx + ce; + double ax_2ce = ax_ce + ce; + double ax_3ce = ax_2ce + ce; + g0 = ax_3ce * rx * rx; + g1 = ax_2ce * rx * ry; + g2 = ax_2ce * rx * rz; + g3 = ax_ce * ry * ry; + g4 = ax_ce * ry * rz; + g5 = ax_ce * rz * rz; g6 = ax * ry * ry * ry; g7 = ax * ry * ry * rz; g8 = ax * ry * rz * rz; @@ -1043,16 +1138,19 @@ static void _sph_kernel_deriv1(BasOffsets offsets) gtox[6*ngrids+grid_id] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3; double ay = ce_2a * ry; - g0 = ay * rx * rx * rx; - g1 = (ay * ry + ce) * rx * rx; - g2 = ay * rx * rx * rz; - g3 = (ay * ry + 2 * ce) * rx * ry; - g4 = (ay * ry + ce) * rx * rz; - g5 = ay * rx * rz * rz; - g6 = (ay * ry + 3 * ce) * ry * ry; - g7 = (ay * ry + 2 * ce) * ry * rz; - g8 = (ay * ry + ce) * rz * rz; - g9 = ay * rz * rz * rz; + double ay_ce = ay * ry + ce; + double ay_2ce = ay_ce + ce; + double ay_3ce = ay_2ce + ce; + g0 = ay * rx * rx * rx; + g1 = ay_ce * rx * rx; + g2 = ay * rx * rx * rz; + g3 = ay_2ce * rx * ry; + g4 = ay_ce * rx * rz; + g5 = ay * rx * rz * rz; + g6 = ay_3ce * ry * ry; + g7 = ay_2ce * ry * rz; + g8 = ay_ce * rz * rz; + g9 = ay * rz * rz * rz; gtoy[ grid_id] = 1.770130769779930531 * g1 - 0.590043589926643510 * g6; gtoy[1*ngrids+grid_id] = 2.890611442640554055 * g4; gtoy[2*ngrids+grid_id] = 1.828183197857862944 * g8 - 0.457045799464465739 * (g1 + g6); @@ -1062,16 +1160,19 @@ static void _sph_kernel_deriv1(BasOffsets offsets) gtoy[6*ngrids+grid_id] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3; double az = ce_2a * rz; - g0 = az * rx * rx * rx; - g1 = az * rx * rx * ry; - g2 = (az * rz + ce) * rx * rx; - g3 = az * rx * ry * ry; - g4 = (az * rz + ce) * rx * ry; - g5 = (az * rz + 2 * ce) * rx * rz; - g6 = az * ry * ry * ry; - g7 = (az * rz + ce) * ry * ry; - g8 = (az * rz + 2 * ce) * ry * rz; - g9 = (az * rz + 3 * ce) * rz * rz; + double az_ce = az * rz + ce; + double az_2ce = az_ce + ce; + double az_3ce = az_2ce + ce; + g0 = az * rx * rx * rx; + g1 = az * rx * rx * ry; + g2 = az_ce * rx * rx; + g3 = az * rx * ry * ry; + g4 = az_ce * rx * ry; + g5 = az_2ce * rx * rz; + g6 = az * ry * ry * ry; + g7 = az_ce * ry * ry; + g8 = az_2ce * ry * rz; + g9 = az_3ce * rz * rz; gtoz[ grid_id] = 1.770130769779930531 * g1 - 0.590043589926643510 * g6; gtoz[1*ngrids+grid_id] = 2.890611442640554055 * g4; gtoz[2*ngrids+grid_id] = 1.828183197857862944 * g8 - 0.457045799464465739 * (g1 + g6); From 6cf2ee3b6219bad9c0796cf18ee9c33f086e8e1c Mon Sep 17 00:00:00 2001 From: Qiming Sun Date: Wed, 10 Jan 2024 16:04:47 -0800 Subject: [PATCH 03/10] zero-copy np arrays (#76) * Add empty_mapped for zero-copy * Add takebak * Add test --- gpu4pyscf/lib/cupy_helper.py | 47 +++++++++++++++++++++- gpu4pyscf/lib/cupy_helper/take_last2d.cu | 50 ++++++++++++++++++++++-- gpu4pyscf/lib/tests/test_cupy_helper.py | 13 +++++- 3 files changed, 103 insertions(+), 7 deletions(-) diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py index 3d0d25f5..bbaf2ebc 100644 --- a/gpu4pyscf/lib/cupy_helper.py +++ b/gpu4pyscf/lib/cupy_helper.py @@ -240,12 +240,12 @@ def block_diag(blocks, out=None): def take_last2d(a, indices, out=None): ''' - reorder the last 2 dimensions with 'indices', the first n-2 indices do not change - shape in the last 2 dimensions have to be the same + Reorder the last 2 dimensions as a[..., indices[:,None], indices] ''' assert a.flags.c_contiguous assert a.shape[-1] == a.shape[-2] nao = a.shape[-1] + assert len(indices) == nao if a.ndim == 2: count = 1 else: @@ -266,6 +266,35 @@ def take_last2d(a, indices, out=None): raise RuntimeError('failed in take_last2d kernel') return out +def takebak(out, a, indices, axis=-1): + '''(experimental) + Take elements from a NumPy array along an axis and write to CuPy array. + out[..., indices] = a + ''' + assert axis == -1 + assert isinstance(a, np.ndarray) + assert isinstance(out, cupy.ndarray) + assert out.ndim == a.ndim + assert a.shape[-1] == len(indices) + if a.ndim == 1: + count = 1 + else: + assert out.shape[:-1] == a.shape[:-1] + count = np.prod(a.shape[:-1]) + n_a = a.shape[-1] + n_o = out.shape[-1] + indices_int32 = cupy.asarray(indices, dtype=cupy.int32) + stream = cupy.cuda.get_current_stream() + err = libcupy_helper.takebak( + ctypes.c_void_p(stream.ptr), + ctypes.c_void_p(out.data.ptr), a.ctypes, + ctypes.c_void_p(indices_int32.data.ptr), + ctypes.c_int(count), ctypes.c_int(n_o), ctypes.c_int(n_a) + ) + if err != 0: # Not the mapped host memory + out[...,indices] = cupy.asarray(a) + return out + def transpose_sum(a, stream=None): ''' return a + a.transpose(0,2,1) @@ -497,3 +526,17 @@ def _qr(xs, dot, lindep=1e-14): def _gen_x0(v, xs): return cupy.dot(v.T, xs) + +def empty_mapped(shape, dtype=float, order='C'): + '''(experimental) + Returns a new, uninitialized NumPy array with the given shape and dtype. + + This is a convenience function which is just :func:`numpy.empty`, + except that the underlying buffer is a pinned and mapped memory. + This array can be used as the buffer of zero-copy memory. + ''' + nbytes = np.prod(shape) * np.dtype(dtype).itemsize + mem = cupy.cuda.PinnedMemoryPointer( + cupy.cuda.PinnedMemory(nbytes, cupy.cuda.runtime.hostAllocMapped), 0) + out = np.ndarray(shape, dtype=dtype, buffer=mem, order=order) + return out diff --git a/gpu4pyscf/lib/cupy_helper/take_last2d.cu b/gpu4pyscf/lib/cupy_helper/take_last2d.cu index 26a1a6a6..4b671211 100644 --- a/gpu4pyscf/lib/cupy_helper/take_last2d.cu +++ b/gpu4pyscf/lib/cupy_helper/take_last2d.cu @@ -17,11 +17,12 @@ #include #include #define THREADS 32 +#define COUNT_BLOCK 80 __global__ -static void _take(double *a, const double *b, int *indices, int n) +static void _take_last2d(double *a, const double *b, int *indices, int n) { - int i = blockIdx.z; + size_t i = blockIdx.z; int j = blockIdx.x * blockDim.x + threadIdx.x; int k = blockIdx.y * blockDim.y + threadIdx.y; if (j >= n || k >= n) { @@ -35,6 +36,27 @@ static void _take(double *a, const double *b, int *indices, int n) a[off + j * n + k] = b[off + j_b * n + k_b]; } +__global__ +static void _takebak(double *out, double *a, int *indices, + int count, int n_o, int n_a) +{ + int i0 = blockIdx.y * COUNT_BLOCK; + int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j > n_a) { + return; + } + + // a is on host with zero-copy memory. We need enough iterations for + // data prefetch to hide latency + int i1 = i0 + COUNT_BLOCK; + if (i1 > count) i1 = count; + int jp = indices[j]; +#pragma unroll + for (size_t i = i0; i < i1; ++i) { + out[i * n_o + jp] = a[i * n_a + j]; + } +} + extern "C" { int take_last2d(cudaStream_t stream, double *a, const double *b, int *indices, int blk_size, int n) { @@ -42,11 +64,33 @@ int take_last2d(cudaStream_t stream, double *a, const double *b, int *indices, i int ntile = (n + THREADS - 1) / THREADS; dim3 threads(THREADS, THREADS); dim3 blocks(ntile, ntile, blk_size); - _take<<>>(a, b, indices, n); + _take_last2d<<>>(a, b, indices, n); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { return 1; } return 0; } + +int takebak(cudaStream_t stream, double *out, double *a_h, int *indices, + int count, int n_o, int n_a) +{ + double *a_d; + cudaError_t err; + err = cudaHostGetDevicePointer(&a_d, a_h, 0); // zero-copy check + if (err != cudaSuccess) { + return 1; + } + + int ntile = (n_a + THREADS*THREADS - 1) / (THREADS*THREADS); + int ncount = (count + COUNT_BLOCK - 1) / COUNT_BLOCK; + dim3 threads(THREADS*THREADS); + dim3 blocks(ntile, ncount); + _takebak<<>>(out, a_d, indices, count, n_o, n_a); + err = cudaGetLastError(); + if (err != cudaSuccess) { + return 1; + } + return 0; +} } diff --git a/gpu4pyscf/lib/tests/test_cupy_helper.py b/gpu4pyscf/lib/tests/test_cupy_helper.py index 614e0fa3..fed01456 100644 --- a/gpu4pyscf/lib/tests/test_cupy_helper.py +++ b/gpu4pyscf/lib/tests/test_cupy_helper.py @@ -18,7 +18,7 @@ import cupy from gpu4pyscf.lib.cupy_helper import ( take_last2d, transpose_sum, krylov, unpack_sparse, - add_sparse) + add_sparse, takebak, empty_mapped) class KnownValues(unittest.TestCase): def test_take_last2d(self): @@ -69,6 +69,15 @@ def test_sparse(self): add_sparse(a, b, indices) assert cupy.linalg.norm(a - a0) < 1e-10 + def test_takebak(self): + a = empty_mapped((5, 8)) + a[:] = 1. + idx = numpy.arange(8) * 2 + out = cupy.zeros((5, 16)) + takebak(out, a, idx) + out[:,idx] -= 1. + assert abs(out).sum() == 0. + if __name__ == "__main__": print("Full tests for cupy helper module") - unittest.main() \ No newline at end of file + unittest.main() From e2b9eded44ea51c94d991f90c209fbacb17b2fa3 Mon Sep 17 00:00:00 2001 From: Zhengxiao Wu <15712521+t0saki@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:57:04 +0800 Subject: [PATCH 04/10] built-in s-dftd3/dftd4 with OpenMP disabled (#79) * change dftd3 source * add to cmdclass * fix build issue * add dftd4 * change internal dftd4 * use bash to run * build using run * use string cmd --- builder/build_dftdx.sh | 54 ++++++++++++++++++++++++++++++++++++++++ gpu4pyscf/dft/rks.py | 4 +-- gpu4pyscf/grad/rks.py | 4 +-- gpu4pyscf/hessian/rks.py | 4 +-- requirements.txt | 2 +- setup.py | 36 +++++++++++++++++++++++++-- 6 files changed, 95 insertions(+), 9 deletions(-) create mode 100644 builder/build_dftdx.sh diff --git a/builder/build_dftdx.sh b/builder/build_dftdx.sh new file mode 100644 index 00000000..05e43e3c --- /dev/null +++ b/builder/build_dftdx.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +WORK_DIR="./tmp" +rm -r ${WORK_DIR} +mkdir -p ${WORK_DIR} + +PROJECT_NAME=${PROJECT_NAME:-"dftd3"} + +SOURCE_URL=${SOURCE_URL:-"https://github.com/dftd3/simple-dftd3/releases/download/v1.0.0/dftd3-1.0.0-sdist.tar.gz"} + +TAR_GZ_NAME=$(basename ${SOURCE_URL}) + +BUILD_DIR="${WORK_DIR}/_build" +INSTALL_DIR="${WORK_DIR}/${PROJECT_NAME}-build" + +pip3 install meson ninja + +cd ${WORK_DIR} + +echo "Downloading source code from $SOURCE_URL..." +curl -L $SOURCE_URL -o $TAR_GZ_NAME + +echo "Extracting $TAR_GZ_NAME..." +tar -xzf $TAR_GZ_NAME + +SOURCE_DIR=$(tar -tf $TAR_GZ_NAME | head -1 | cut -f1 -d"/") +cd $SOURCE_DIR + +echo " +option( + 'openmp', + type: 'boolean', + value: false, + yield: true, + description: 'Use OpenMP parallelisation', +)" >> meson_options.txt + +echo "Setting up build system with meson..." +meson setup --wipe $BUILD_DIR -Dopenmp=false + +echo "Compiling the code..." +meson compile -C $BUILD_DIR + +echo "Configuring build system with prefix..." +meson configure $BUILD_DIR --prefix=$(realpath ${INSTALL_DIR}) + +echo "Installing to $INSTALL_DIR..." +meson install -C $BUILD_DIR + +echo "Installation complete." + +cd ../../ + +echo "All operations completed." \ No newline at end of file diff --git a/gpu4pyscf/dft/rks.py b/gpu4pyscf/dft/rks.py index 7b9dd588..9af80ab2 100644 --- a/gpu4pyscf/dft/rks.py +++ b/gpu4pyscf/dft/rks.py @@ -250,7 +250,7 @@ def get_dispersion(self): # multi-threads in DFTD3 conflicts with PyTorch, set it to be 1 for safty from pyscf import lib with lib.with_omp_threads(1): - import dftd3.pyscf as disp + import gpu4pyscf.dftd3.pyscf as disp d3 = disp.DFTD3Dispersion(self.mol, xc=self.xc, version=self.disp) e_d3, _ = d3.kernel() return e_d3 @@ -261,7 +261,7 @@ def get_dispersion(self): coords = self.mol.atom_coords() from pyscf import lib with lib.with_omp_threads(1): - from dftd4.interface import DampingParam, DispersionModel + from gpu4pyscf.dftd4.interface import DampingParam, DispersionModel model = DispersionModel(atoms, coords) res = model.get_dispersion(DampingParam(method=self.xc), grad=False) return res.get("energy") diff --git a/gpu4pyscf/grad/rks.py b/gpu4pyscf/grad/rks.py index 82b0befd..4b0b5191 100644 --- a/gpu4pyscf/grad/rks.py +++ b/gpu4pyscf/grad/rks.py @@ -504,7 +504,7 @@ def get_dispersion(self): if self.base.disp[:2].upper() == 'D3': from pyscf import lib with lib.with_omp_threads(1): - import dftd3.pyscf as disp + import gpu4pyscf.dftd3.pyscf as disp d3 = disp.DFTD3Dispersion(self.mol, xc=self.base.xc, version=self.base.disp) _, g_d3 = d3.kernel() return g_d3 @@ -516,7 +516,7 @@ def get_dispersion(self): from pyscf import lib with lib.with_omp_threads(1): - from dftd4.interface import DampingParam, DispersionModel + from gpu4pyscf.dftd4.interface import DampingParam, DispersionModel model = DispersionModel(atoms, coords) res = model.get_dispersion(DampingParam(method=self.base.xc), grad=True) return res.get("gradient") \ No newline at end of file diff --git a/gpu4pyscf/hessian/rks.py b/gpu4pyscf/hessian/rks.py index 2082e707..43f97c7b 100644 --- a/gpu4pyscf/hessian/rks.py +++ b/gpu4pyscf/hessian/rks.py @@ -681,7 +681,7 @@ def get_dispersion(self): if self.base.disp[:2].upper() == 'D3': from pyscf import lib with lib.with_omp_threads(1): - import dftd3.pyscf as disp + import gpu4pyscf.dftd3.pyscf as disp coords = self.mol.atom_coords() natm = self.mol.natm h_d3 = numpy.zeros([self.mol.natm, self.mol.natm, 3,3]) @@ -710,7 +710,7 @@ def get_dispersion(self): natm = self.mol.natm from pyscf import lib with lib.with_omp_threads(1): - from dftd4.interface import DampingParam, DispersionModel + from gpu4pyscf.dftd4.interface import DampingParam, DispersionModel params = DampingParam(method=self.base.xc) mol = self.mol.copy() h_d3 = numpy.zeros([self.mol.natm, self.mol.natm, 3,3]) diff --git a/requirements.txt b/requirements.txt index 73a2c439..be4060b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ cutensor==1.6.0.3 cutensor-cu11==1.6.1 cutensornet-cu11==2.0.0 Cython==0.29.32 -dftd3==0.7.0 +# dftd3==0.7.0 geometric==1.0 h5py==3.7.0 numpy==1.23.5 diff --git a/setup.py b/setup.py index 57372a31..3c842953 100755 --- a/setup.py +++ b/setup.py @@ -21,6 +21,8 @@ import sys import subprocess import re +import glob +import subprocess from setuptools import setup, find_packages, Extension from setuptools.command.build_py import build_py @@ -84,8 +86,34 @@ def run(self): self.announce(' '.join(cmd)) else: self.spawn(cmd) + + self.build_dftd('dftd3', 'https://github.com/dftd3/simple-dftd3/releases/download/v1.0.0/dftd3-1.0.0-sdist.tar.gz') + self.build_dftd('dftd4', 'https://github.com/dftd4/dftd4/releases/download/v3.6.0/dftd4-sdist-3.6.0.tar.gz') + super().run() + def build_dftd(self,project_name,source_url): + self.plat_name = get_platform() + self.build_base = 'build' + self.build_lib = os.path.join(self.build_base, 'lib') + self.build_temp = os.path.join(self.build_base, f'temp.{self.plat_name}') + + script_path = 'builder/build_dftdx.sh' + if not os.path.exists(script_path): + raise FileNotFoundError("Cannot find build script: {}".format(script_path)) + + subprocess.run(f"PROJECT_NAME={project_name} SOURCE_URL={source_url} sh {script_path}", shell=True, check=True) + + build_dir_pattern = f'tmp/{project_name}-*/tmp/{project_name}-build/lib/python3/dist-packages/{project_name}' + build_dirs = glob.glob(build_dir_pattern) + if not len(build_dirs) == 1: + raise FileNotFoundError("Cannot find build directory: {}".format(build_dir_pattern)) + build_dir = build_dirs[0] + + target_dir = os.path.join(self.build_lib, 'gpu4pyscf', project_name) + self.copy_tree(build_dir, target_dir) + + # build_py will produce plat_name = 'any'. Patch the bdist_wheel to change the # platform tag because the C extensions are platform dependent. from wheel.bdist_wheel import bdist_wheel @@ -124,9 +152,13 @@ def initialize_with_default_plat_name(self): install_requires=[ 'pyscf>=2.4.0', f'cupy-cuda{CUDA_VERSION}>=12.0', - 'dftd3==0.7.0', - 'dftd4==3.5.0', + # 'dftd3==0.7.0', + # 'dftd4==3.5.0', 'geometric', f'gpu4pyscf-libxc-cuda{CUDA_VERSION}', ], + package_data={ + "gpu4pyscf.dftd3": ["_libdftd3*.so", "parameters.toml"], + "gpu4pyscf.dftd4": ["_libdftd4*.so", "*.toml", "*.json"], + }, ) From 19dc057ce4c5eaea0b9fe221ba6edcc9b9d9bb00 Mon Sep 17 00:00:00 2001 From: Xiaojie Wu Date: Fri, 12 Jan 2024 10:57:28 -0800 Subject: [PATCH 05/10] V0.6.15 (#80) * fixed a bug in screen_index * added unit test for to_gpu * new grids group scheme * use grid_aligned in gpu4pyscf.__config__ * fixed a bug in eval_ao * fixed a bug in transpose_sum * remove print * new get_vxc scheme * added temp examples * short example * optimized initialization * updated version --- examples/dft_driver.py | 1 + gpu4pyscf/__config__.py | 2 +- gpu4pyscf/__init__.py | 2 +- gpu4pyscf/df/df.py | 42 ++---- gpu4pyscf/df/df_jk.py | 5 +- gpu4pyscf/df/hessian/rhf.py | 14 +- gpu4pyscf/df/int3c2e.py | 46 +++--- gpu4pyscf/df/tests/test_int3c2e.py | 26 ++++ gpu4pyscf/dft/gen_grid.py | 8 +- gpu4pyscf/dft/numint.py | 183 ++++++++++------------- gpu4pyscf/dft/rks.py | 1 + gpu4pyscf/grad/rhf.py | 46 +++++- gpu4pyscf/grad/rks.py | 32 ++-- gpu4pyscf/hessian/rhf.py | 78 ++++++---- gpu4pyscf/hessian/rks.py | 85 ++++++----- gpu4pyscf/lib/cupy_helper.py | 45 ++++-- gpu4pyscf/lib/cupy_helper/CMakeLists.txt | 3 +- gpu4pyscf/lib/cupy_helper/dist_matrix.cu | 49 ++++++ gpu4pyscf/lib/gdft/gen_grids.cu | 4 +- gpu4pyscf/lib/gdft/nr_eval_gto.cu | 12 +- gpu4pyscf/lib/gint/g3c2e.cu | 8 +- gpu4pyscf/lib/gint/gout3c2e.cu | 50 +++++-- gpu4pyscf/lib/gint/nr_fill_ao_int3c2e.cu | 16 +- gpu4pyscf/lib/gvhf/g3c2e.cuh | 60 ++++---- gpu4pyscf/lib/gvhf/g3c2e_pass1.cu | 12 +- gpu4pyscf/lib/gvhf/g3c2e_pass1_root1.cu | 11 +- gpu4pyscf/lib/gvhf/g3c2e_pass2.cu | 10 +- gpu4pyscf/lib/gvhf/g3c2e_pass2_root1.cu | 13 +- gpu4pyscf/lib/logger.py | 8 + gpu4pyscf/lib/tests/test_cupy_helper.py | 9 +- gpu4pyscf/solvent/grad/pcm.py | 17 ++- gpu4pyscf/solvent/hessian/pcm.py | 5 +- gpu4pyscf/solvent/pcm.py | 6 +- 33 files changed, 574 insertions(+), 335 deletions(-) create mode 100644 gpu4pyscf/lib/cupy_helper/dist_matrix.cu diff --git a/examples/dft_driver.py b/examples/dft_driver.py index 80830e6c..d7479292 100644 --- a/examples/dft_driver.py +++ b/examples/dft_driver.py @@ -27,6 +27,7 @@ parser.add_argument("--solvent", type=str, default='') args = parser.parse_args() +lib.num_threads(16) start_time = time.time() bas = args.basis mol = pyscf.M( diff --git a/gpu4pyscf/__config__.py b/gpu4pyscf/__config__.py index 5b740207..93346e36 100644 --- a/gpu4pyscf/__config__.py +++ b/gpu4pyscf/__config__.py @@ -5,7 +5,7 @@ # such as A100-80G if props['totalGlobalMem'] >= 64 * GB: min_ao_blksize = 256 - min_grid_blksize = 256*256 + min_grid_blksize = 128*128 ao_aligned = 32 grid_aligned = 128 mem_fraction = 0.9 diff --git a/gpu4pyscf/__init__.py b/gpu4pyscf/__init__.py index 5fd45d2f..12273fed 100644 --- a/gpu4pyscf/__init__.py +++ b/gpu4pyscf/__init__.py @@ -1,5 +1,5 @@ from . import lib, grad, hessian, solvent, scf, dft -__version__ = '0.6.14' +__version__ = '0.6.15' # monkey patch libxc reference due to a bug in nvcc from pyscf.dft import libxc diff --git a/gpu4pyscf/df/df.py b/gpu4pyscf/df/df.py index ff3c6877..14294843 100644 --- a/gpu4pyscf/df/df.py +++ b/gpu4pyscf/df/df.py @@ -57,18 +57,6 @@ def build(self, direct_scf_tol=1e-14, omega=None): auxmol = self.auxmol self.nao = mol.nao - # cache indices for better performance - nao = mol.nao - tril_row, tril_col = cupy.tril_indices(nao) - tril_row = cupy.asarray(tril_row) - tril_col = cupy.asarray(tril_col) - - self.tril_row = tril_row - self.tril_col = tril_col - - idx = np.arange(nao) - self.diag_idx = cupy.asarray(idx*(idx+1)//2+idx) - log = logger.new_logger(mol, mol.verbose) t0 = log.init_timer() if auxmol is None: @@ -147,7 +135,7 @@ def loop(self, blksize=None, unpack=True): rows = self.intopt.cderi_row cols = self.intopt.cderi_col buf_prefetch = None - + buf_cderi = cupy.zeros([blksize,nao,nao]) data_stream = cupy.cuda.stream.Stream(non_blocking=True) compute_stream = cupy.cuda.get_current_stream() #compute_stream = cupy.cuda.stream.Stream() @@ -165,14 +153,15 @@ def loop(self, blksize=None, unpack=True): buf_prefetch.set(cderi_sparse[p1:p2,:]) stop_event = data_stream.record() if unpack: - buf2 = cupy.zeros([p1-p0,nao,nao]) - buf2[:p1-p0,rows,cols] = buf - buf2[:p1-p0,cols,rows] = buf + buf_cderi[:p1-p0,rows,cols] = buf + buf_cderi[:p1-p0,cols,rows] = buf + buf2 = buf_cderi[:p1-p0] else: buf2 = None yield buf2, buf.T compute_stream.wait_event(stop_event) - cupy.cuda.Device().synchronize() + if isinstance(cderi_sparse, np.ndarray): + cupy.cuda.Device().synchronize() if buf_prefetch is not None: buf = buf_prefetch @@ -217,8 +206,8 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False): cderi = np.ndarray([naux, npair], dtype=np.float64, order='C', buffer=mem) except Exception: raise RuntimeError('Out of CPU memory') - - data_stream = cupy.cuda.stream.Stream(non_blocking=False) + if(not use_gpu_memory): + data_stream = cupy.cuda.stream.Stream(non_blocking=False) count = 0 nq = len(intopt.log_qs) for cp_ij_id, _ in enumerate(intopt.log_qs): @@ -234,20 +223,20 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False): nj = j1 - j0 if sr_only: # TODO: in-place implementation or short-range kernel - ints_slices = cupy.zeros([naoaux, nj, ni], order='C') + ints_slices = cupy.empty([naoaux, nj, ni], order='C') for cp_kl_id, _ in enumerate(intopt.aux_log_qs): k0 = intopt.sph_aux_loc[cp_kl_id] k1 = intopt.sph_aux_loc[cp_kl_id+1] int3c2e.get_int3c2e_slice(intopt, cp_ij_id, cp_kl_id, out=ints_slices[k0:k1]) if omega is not None: - ints_slices_lr = cupy.zeros([naoaux, nj, ni], order='C') + ints_slices_lr = cupy.empty([naoaux, nj, ni], order='C') for cp_kl_id, _ in enumerate(intopt.aux_log_qs): k0 = intopt.sph_aux_loc[cp_kl_id] k1 = intopt.sph_aux_loc[cp_kl_id+1] int3c2e.get_int3c2e_slice(intopt, cp_ij_id, cp_kl_id, out=ints_slices[k0:k1], omega=omega) ints_slices -= ints_slices_lr else: - ints_slices = cupy.zeros([naoaux, nj, ni], order='C') + ints_slices = cupy.empty([naoaux, nj, ni], order='C') for cp_kl_id, _ in enumerate(intopt.aux_log_qs): k0 = intopt.sph_aux_loc[cp_kl_id] k1 = intopt.sph_aux_loc[cp_kl_id+1] @@ -261,11 +250,8 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False): row = intopt.ao_pairs_row[cp_ij_id] - i0 col = intopt.ao_pairs_col[cp_ij_id] - j0 - if cpi == cpj: - #ints_slices = ints_slices + ints_slices.transpose([0,2,1]) - transpose_sum(ints_slices) - ints_slices = ints_slices[:,col,row] + ints_slices = ints_slices[:,col,row] if cd_low.tag == 'eig': cderi_block = cupy.dot(cd_low.T, ints_slices) ints_slices = None @@ -281,8 +267,8 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False): for i in range(naux): cderi_block[i].get(out=cderi[i,ij0:ij1]) t1 = log.timer_debug1(f'solve {cp_ij_id} / {nq}', *t1) - - cupy.cuda.Device().synchronize() + if not use_gpu_memory: + cupy.cuda.Device().synchronize() return cderi diff --git a/gpu4pyscf/df/df_jk.py b/gpu4pyscf/df/df_jk.py index ff0cbd7e..72c01e1d 100644 --- a/gpu4pyscf/df/df_jk.py +++ b/gpu4pyscf/df/df_jk.py @@ -63,7 +63,6 @@ def build_df(): rks.initialize_grids(mf, mf.mol, dm0) ni.build(mf.mol, mf.grids.coords) mf._numint.xcfuns = numint._init_xcfuns(mf.xc, dm0.ndim==3) - dm0 = cupy.asarray(dm0) return def _density_fit(mf, auxbasis=None, with_df=None, only_dfj=False): @@ -253,10 +252,11 @@ def get_jk(dfobj, dms_tag, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e- nao = dms_tag.shape[-1] dms = dms_tag.reshape([-1,nao,nao]) nset = dms.shape[0] - t0 = log.init_timer() + t1 = t0 = log.init_timer() if dfobj._cderi is None: log.debug('CDERI not found, build...') dfobj.build(direct_scf_tol=direct_scf_tol, omega=omega) + t1 = log.timer_debug1('init jk', *t0) assert nao == dfobj.nao vj = None @@ -264,7 +264,6 @@ def get_jk(dfobj, dms_tag, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e- ao_idx = dfobj.intopt.sph_ao_idx dms = take_last2d(dms, ao_idx) - t1 = log.timer_debug1('init jk', *t0) rows = dfobj.intopt.cderi_row cols = dfobj.intopt.cderi_col if with_j: diff --git a/gpu4pyscf/df/hessian/rhf.py b/gpu4pyscf/df/hessian/rhf.py index ff3d0ef2..a6e793f6 100644 --- a/gpu4pyscf/df/hessian/rhf.py +++ b/gpu4pyscf/df/hessian/rhf.py @@ -37,6 +37,7 @@ import cupy import numpy as np from pyscf import lib, df +from gpu4pyscf.grad import rhf as rhf_grad from gpu4pyscf.hessian import rhf as rhf_hess from gpu4pyscf.lib.cupy_helper import contract, tag_array, release_gpu_stack, print_mem_info, take_last2d from gpu4pyscf.df import int3c2e @@ -283,6 +284,8 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None, hk_aux_aux += .5 * contract('pqxy,pq->pqxy', rho2c_11, int2c_inv) # (00|1)(1|00) rho2c_0 = rho2c_10 = rho2c_11 = rho2c0_10 = rho2c1_10 = rho2c0_11 = int2c_ip_ip = None wk_ip2_P__ = int2c_ip1_inv = None + t1 = log.timer_debug1('contract int2c_*', *t1) + ao_idx = np.argsort(intopt.sph_ao_idx) aux_idx = np.argsort(intopt.sph_aux_idx) rev_ao_ao = cupy.ix_(ao_idx, ao_idx) @@ -372,7 +375,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None, e1[j0,i0] = e1[i0,j0].T ej[j0,i0] = ej[i0,j0].T ek[j0,i0] = ek[i0,j0].T - + t1 = log.timer_debug1('hcore contribution', *t1) log.timer('RHF partial hessian', *time0) return e1, ej, ek @@ -398,6 +401,7 @@ def make_h1(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None, verbose=None): else: return chkfile ''' + def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None, verbose=None, with_k=True, omega=None): log = logger.new_logger(hessobj, verbose) @@ -521,8 +525,9 @@ def _ao2mo(mat): vk1_int3c = vk1_int3c_ip1 + vk1_int3c_ip2 vk1_int3c_ip1 = vk1_int3c_ip2 = None + grad_hcore = rhf_grad.get_grad_hcore(hessobj.base.nuc_grad_method()) cupy.get_default_memory_pool().free_all_blocks() - hcore_deriv = hessobj.base.nuc_grad_method().hcore_generator(mol) + #hcore_deriv = hessobj.base.nuc_grad_method().hcore_generator(mol) vk1 = None for i0, ia in enumerate(atmlst): shl0, shl1, p0, p1 = aoslices[ia] @@ -535,8 +540,9 @@ def _ao2mo(mat): vk1_ao[:,p0:p1,:] -= vk1_buf[:,p0:p1,:] vk1_ao[:,:,p0:p1] -= vk1_buf[:,p0:p1,:].transpose(0,2,1) - h1 = hcore_deriv(ia) - h1 = _ao2mo(cupy.asarray(h1, order='C')) + h1 = grad_hcore[:,i0] + #h1 = hcore_deriv(ia) + #h1 = _ao2mo(cupy.asarray(h1, order='C')) vj1 = vj1_int3c[ia] + _ao2mo(vj1_ao) if with_k: vk1 = vk1_int3c[ia] + _ao2mo(vk1_ao) diff --git a/gpu4pyscf/df/int3c2e.py b/gpu4pyscf/df/int3c2e.py index 56668354..e45f11dc 100644 --- a/gpu4pyscf/df/int3c2e.py +++ b/gpu4pyscf/df/int3c2e.py @@ -306,10 +306,7 @@ def build(self, cutoff=1e-14, group_size=None, ncptype = len(log_qs) self.bpcache = ctypes.POINTER(BasisProdCache)() - if diag_block_with_triu: - scale_shellpair_diag = 1. - else: - scale_shellpair_diag = 0.5 + scale_shellpair_diag = 1. libgint.GINTinit_basis_prod( ctypes.byref(self.bpcache), ctypes.c_double(scale_shellpair_diag), ao_loc.ctypes.data_as(ctypes.c_void_p), @@ -1194,6 +1191,32 @@ def get_dh1e(mol, dm0): dh1e[k0:k1,:3] += cupy.einsum('xkji,ij->kx', int3c_blk, dm0_sorted[i0:i1,j0:j1]) return 2.0 * cupy.einsum('kx,k->kx', dh1e, -charges) +def get_d2h1e(mol, dm0): + natm = mol.natm + coords = mol.atom_coords() + charges = mol.atom_charges() + fakemol = gto.fakemol_for_charges(coords) + + nao = mol.nao + d2h1e_diag = cupy.zeros([natm,9]) + d2h1e_offdiag = cupy.zeros([natm, nao, 9]) + intopt = VHFOpt(mol, fakemol, 'int2e') + intopt.build(1e-14, diag_block_with_triu=True, aosym=False, group_size=BLKSIZE, group_size_aux=BLKSIZE) + dm0_sorted = take_last2d(dm0, intopt.sph_ao_idx) + for i0,i1,j0,j1,k0,k1,int3c_blk in loop_int3c2e_general(intopt, ip_type='ipip1'): + d2h1e_diag[k0:k1,:9] -= contract('xaji,ij->ax', int3c_blk, dm0_sorted[i0:i1,j0:j1]) + d2h1e_offdiag[k0:k1,i0:i1,:9] += contract('xaji,ij->aix', int3c_blk, dm0_sorted[i0:i1,j0:j1]) + + for i0,i1,j0,j1,k0,k1,int3c_blk in loop_int3c2e_general(intopt, ip_type='ipvip1'): + d2h1e_diag[k0:k1,:9] -= contract('xaji,ij->ax', int3c_blk, dm0_sorted[i0:i1,j0:j1]) + d2h1e_offdiag[k0:k1,i0:i1,:9] += contract('xaji,ij->aix', int3c_blk, dm0_sorted[i0:i1,j0:j1]) + aoslices = mol.aoslice_by_atom() + ao2atom = get_ao2atom(intopt, aoslices) + d2h1e = contract('aix,ib->abx', d2h1e_offdiag, ao2atom) + d2h1e[np.diag_indices(natm), :] += d2h1e_diag + return 2.0 * cupy.einsum('abx,a->xab', d2h1e, charges) + #return 2.0 * cupy.einsum('ijx,i->kx', dh1e, -charges) + def get_int3c2e_slice(intopt, cp_ij_id, cp_aux_id, aosym=None, out=None, omega=None, stream=None): ''' Generate one int3c2e block for given ij, k @@ -1443,14 +1466,6 @@ def get_pairing(p_offsets, q_offsets, q_cond, for q0, q1 in zip(q_offsets[:-1], q_offsets[1:]): if aosym and q0 < p0 or not aosym: q_sub = q_cond[p0:p1,q0:q1].ravel() - ''' - idx = q_sub.argsort(axis=None)[::-1] - q_sorted = q_sub[idx] - mask = q_sorted > cutoff - idx = idx[mask] - ishs, jshs = np.unravel_index(idx, (p1-p0, q1-q0)) - print(ishs.shape) - ''' mask = q_sub > cutoff ishs, jshs = np.indices((p1-p0,q1-q0)) ishs = ishs.ravel()[mask] @@ -1464,13 +1479,6 @@ def get_pairing(p_offsets, q_offsets, q_cond, log_qs.append(log_q) elif aosym and p0 == q0 and p1 == q1: q_sub = q_cond[p0:p1,p0:p1].ravel() - ''' - idx = q_sub.argsort(axis=None)[::-1] - q_sorted = q_sub[idx] - ishs, jshs = np.unravel_index(idx, (p1-p0, p1-p0)) - mask = q_sorted > cutoff - ''' - ishs, jshs = np.indices((p1-p0, p1-p0)) ishs = ishs.ravel() jshs = jshs.ravel() diff --git a/gpu4pyscf/df/tests/test_int3c2e.py b/gpu4pyscf/df/tests/test_int3c2e.py index 8e02ac3d..6c556061 100644 --- a/gpu4pyscf/df/tests/test_int3c2e.py +++ b/gpu4pyscf/df/tests/test_int3c2e.py @@ -116,6 +116,32 @@ def test_int1e_iprinv(self): h1ao = mol.intor('int1e_iprinv', comp=3) # <\nabla|1/r|> assert np.linalg.norm(int3c[:,:,:,i] - h1ao) < 1e-8 + def test_int1e_ipiprinv(self): + from pyscf import gto + coords = mol.atom_coords() + charges = mol.atom_charges() + + fakemol = gto.fakemol_for_charges(coords) + int3c = int3c2e.get_int3c2e_general(mol, fakemol, ip_type='ipip1').get() + + for i,q in enumerate(charges): + mol.set_rinv_origin(coords[i]) + h1ao = mol.intor('int1e_ipiprinv', comp=9) # <\nabla|1/r|> + assert np.linalg.norm(int3c[:,:,:,i] - h1ao) < 1e-8 + + def test_int1e_iprinvip(self): + from pyscf import gto + coords = mol.atom_coords() + charges = mol.atom_charges() + + fakemol = gto.fakemol_for_charges(coords) + int3c = int3c2e.get_int3c2e_general(mol, fakemol, ip_type='ipvip1').get() + + for i,q in enumerate(charges): + mol.set_rinv_origin(coords[i]) + h1ao = mol.intor('int1e_iprinvip', comp=9) # <\nabla|1/r|> + assert np.linalg.norm(int3c[:,:,:,i] - h1ao) < 1e-8 + if __name__ == "__main__": print("Full Tests for int3c") unittest.main() diff --git a/gpu4pyscf/dft/gen_grid.py b/gpu4pyscf/dft/gen_grid.py index 5ced5d94..5528e1e5 100644 --- a/gpu4pyscf/dft/gen_grid.py +++ b/gpu4pyscf/dft/gen_grid.py @@ -279,14 +279,16 @@ def get_partition(mol, atom_grids_tab, grid_coord and grid_weight arrays. grid_coord array has shape (N,3); weight 1D array has N elements. ''' + atm_coords = numpy.asarray(mol.atom_coords() , order='C') + atm_coords = cupy.asarray(atm_coords) + ''' if callable(radii_adjust) and atomic_radii is not None: f_radii_adjust = radii_adjust(mol, atomic_radii) else: f_radii_adjust = None - atm_coords = numpy.asarray(mol.atom_coords() , order='C') atm_dist = gto.inter_distance(mol) - atm_coords = cupy.asarray(atm_coords) atm_dist = cupy.asarray(atm_dist) + if (becke_scheme is original_becke and (radii_adjust is radi.treutler_atomic_radii_adjust or radii_adjust is radi.becke_atomic_radii_adjust or @@ -324,7 +326,7 @@ def gen_grid_partition(coords): pbecke[i] *= .5 * (1-g) pbecke[j] *= .5 * (1+g) return pbecke - + ''' coords_all = [] weights_all = [] # support atomic_radii_adjust = None diff --git a/gpu4pyscf/dft/numint.py b/gpu4pyscf/dft/numint.py index 0e46c611..f3ecda69 100644 --- a/gpu4pyscf/dft/numint.py +++ b/gpu4pyscf/dft/numint.py @@ -31,12 +31,12 @@ from gpu4pyscf.lib import logger LMAX_ON_GPU = 6 -BAS_ALIGNED = 4 +BAS_ALIGNED = 1 GRID_BLKSIZE = 32 MIN_BLK_SIZE = getattr(__config__, 'min_grid_blksize', 64*64) ALIGNED = getattr(__config__, 'grid_aligned', 16*16) AO_ALIGNMENT = getattr(__config__, 'ao_aligned', 16) -AO_THRESHOLD = 1e-12 +AO_THRESHOLD = 1e-10 # Should we release the cupy cache? FREE_CUPY_CACHE = False @@ -199,14 +199,15 @@ def eval_rho2(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA', rho = _contract_rho(c0, c0) elif xctype in ('GGA', 'NLC'): rho = cupy.empty((4,ngrids)) - c0 = _dot_ao_dm(mol, ao[0], cpos, non0tab, shls_slice, ao_loc) + #c0 = _dot_ao_dm(mol, ao[0], cpos, non0tab, shls_slice, ao_loc) + c0 = contract('nig,io->nog', ao, cpos) #:rho[0] = numpy.einsum('pi,pi->p', c0, c0) - _contract_rho(c0, c0, rho=rho[0]) + _contract_rho(c0[0], c0[0], rho=rho[0]) for i in range(1, 4): - c1 = _dot_ao_dm(mol, ao[i], cpos, non0tab, shls_slice, ao_loc) + #c1 = _dot_ao_dm(mol, ao[i], cpos, non0tab, shls_slice, ao_loc) #:rho[i] = numpy.einsum('pi,pi->p', c0, c1) * 2 # *2 for +c.c. - _contract_rho(c0, c1, rho=rho[i]) - rho[i] *= 2 + _contract_rho(c0[0], c0[i], rho=rho[i]) + rho[1:] *= 2 else: # meta-GGA if with_lapl: # rho[4] = \nabla^2 rho, rho[5] = 1/2 |nabla f|^2 @@ -215,17 +216,18 @@ def eval_rho2(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA', else: rho = cupy.empty((5,ngrids)) tau_idx = 4 - c0 = _dot_ao_dm(mol, ao[0], cpos, non0tab, shls_slice, ao_loc) + #c0 = _dot_ao_dm(mol, ao[0], cpos, non0tab, shls_slice, ao_loc) + c0 = contract('nig,io->nog', ao, cpos) #:rho[0] = numpy.einsum('pi,pi->p', c0, c0) - _contract_rho(c0, c0, rho=rho[0]) + _contract_rho(c0[0], c0[0], rho=rho[0]) rho[tau_idx] = 0 for i in range(1, 4): - c1 = _dot_ao_dm(mol, ao[i], cpos, non0tab, shls_slice, ao_loc) + #c1 = _dot_ao_dm(mol, ao[i], cpos, non0tab, shls_slice, ao_loc) #:rho[i] = numpy.einsum('pi,pi->p', c0, c1) * 2 # *2 for +c.c. #:rho[5] += numpy.einsum('pi,pi->p', c1, c1) - rho[i] = _contract_rho(c0, c1) * 2 - rho[tau_idx] += _contract_rho(c1, c1) + rho[i] = _contract_rho(c0[0], c0[i]) + rho[tau_idx] += _contract_rho(c0[i], c0[i]) if with_lapl: if ao.shape[0] > 4: @@ -233,11 +235,12 @@ def eval_rho2(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA', ao2 = ao[XX] + ao[YY] + ao[ZZ] c1 = _dot_ao_dm(mol, ao2, cpos, non0tab, shls_slice, ao_loc) #:rho[4] = numpy.einsum('pi,pi->p', c0, c1) - rho[4] = _contract_rho(c0, c1) + rho[4] = _contract_rho(c0[0], c1) rho[4] += rho[5] rho[4] *= 2 else: rho[4] = 0 + rho[1:4] *= 2 rho[tau_idx] *= .5 return rho @@ -321,11 +324,14 @@ def eval_rho4(mol, ao, c0, mo1, non0tab=None, xctype='LDA', rho[i] = _contract_rho(c0, c_0[i]) rho *= 2.0 elif xctype in ('GGA', 'NLC'): + log = logger.new_logger(mol, mol.verbose) + t0 = log.init_timer() c_0 = contract('nig,aio->anog', ao, cpos1) + t0 = log.timer_debug1('ao * cpos', *t0) rho = cupy.empty([na, 4, ngrids]) for i in range(na): _contract_rho_gga(c0, c_0[i], rho=rho[i]) - + t0 = log.timer_debug1('contract rho', *t0) else: # meta-GGA if with_lapl: raise NotImplementedError("mGGA with lapl not implemented") @@ -448,11 +454,13 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, nao, nao0 = coeff.shape dms = cupy.asarray(dms) dm_shape = dms.shape - dms = [coeff @ dm @ coeff.T for dm in dms.reshape(-1,nao0,nao0)] + #dms = [coeff @ dm @ coeff.T for dm in dms.reshape(-1,nao0,nao0)] + dms = dms.reshape(-1,nao0,nao0) + dms = take_last2d(dms, opt.ao_idx) nset = len(dms) if mo_coeff is not None: - mo_coeff = coeff @ mo_coeff + mo_coeff = mo_coeff[opt.ao_idx] nelec = cupy.zeros(nset) excsum = cupy.zeros(nset) @@ -482,19 +490,20 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, rho_tot = cupy.empty([nset,6,ngrids]) p0 = p1 = 0 + t1 = t0 = log.init_timer() for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv): p1 = p0 + weight.size for i in range(nset): - t0 = log.init_timer() if mo_coeff is None: rho_tot[i,:,p0:p1] = eval_rho(mol, ao_mask, dms[i][np.ix_(idx,idx)], xctype=xctype, hermi=1) else: mo_coeff_mask = mo_coeff[idx,:] rho_tot[i,:,p0:p1] = eval_rho2(mol, ao_mask, mo_coeff_mask, mo_occ, None, xctype) - t1 = log.timer_debug1('eval rho', *t0) p0 = p1 + t1 = log.timer_debug2('eval rho slice', *t1) + t0 = log.timer_debug1('eval rho', *t0) - vxc_tot = [] + wv = [] for i in range(nset): if xctype == 'LDA': exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i][0], deriv=1, xctype=xctype)[:2] @@ -505,69 +514,37 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, den = rho_tot[i][0] * grids.weights nelec[i] = den.sum() excsum[i] = cupy.sum(den * exc[:,0]) - vxc_tot.append(vxc) - t1 = log.timer_debug1('eval vxc', *t1) + wv.append(vxc * grids.weights) + if xctype == 'GGA': + wv[i][0] *= .5 + if xctype == 'MGGA': + wv[i][[0,4]] *= .5 + t0 = log.timer_debug1('eval vxc', *t0) + t1 = t0 p0 = p1 = 0 for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv): p1 = p0 + weight.size for i in range(nset): - vxc = vxc_tot[i][:,p0:p1] if xctype == 'LDA': - #den = rho * weight - wv = weight * vxc[0] - ''' - if USE_SPARSITY == 0: - vmat[i] += ao.dot(_scale_ao(ao, wv).T) - elif USE_SPARSITY == 1: - _dot_ao_ao_sparse(ao, ao, wv, nbins, sindex, ao_loc, - pair2shls_full, pairs_locs_full, vmat[i]) - ''' if USE_SPARSITY == 2: - aow = _scale_ao(ao_mask, wv) - # vmat[i][cupy.ix_(mask, mask)] += ao_mask.dot(aow.T) + aow = _scale_ao(ao_mask, wv[i][0,p0:p1]) add_sparse(vmat[i], ao_mask.dot(aow.T), idx) else: raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented') elif xctype == 'GGA': - #den = rho[0] * weight - wv = vxc * weight - wv[0] *= .5 - ''' - if USE_SPARSITY == 0: - vmat[i] += ao[0].dot(_scale_ao(ao, wv).T) - elif USE_SPARSITY == 1: - aow = _scale_ao(ao, wv) - _dot_ao_ao_sparse(ao[0], aow, None, nbins, sindex, ao_loc, - pair2shls_full, pairs_locs_full, vmat[i]) - ''' if USE_SPARSITY == 2: - aow = _scale_ao(ao_mask, wv) + aow = _scale_ao(ao_mask, wv[i][:,p0:p1]) add_sparse(vmat[i], ao_mask[0].dot(aow.T), idx) else: raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented') elif xctype == 'NLC': raise NotImplementedError('NLC') elif xctype == 'MGGA': - #den = rho[0] * weight - wv = vxc * weight - wv[[0, 4]] *= .5 # *.5 for v+v.T - ''' - if USE_SPARSITY == 0: - aow = _scale_ao(ao[:4], wv[:4]) - vmat[i] += ao[0].dot(aow.T) - vmat[i] += _tau_dot(ao, ao, wv[4]) - elif USE_SPARSITY == 1: - _dot_ao_ao_sparse(ao[0], aow, None, nbins, sindex, ao_loc, - pair2shls_full, pairs_locs_full, vmat[i]) - _tau_dot_sparse(ao, ao, wv[4], nbins, sindex, ao_loc, - pair2shls_full, pairs_locs_full, vmat[i]) - ''' if USE_SPARSITY == 2: - aow = _scale_ao(ao_mask, wv[:4]) + aow = _scale_ao(ao_mask, wv[i][:4,p0:p1]) vtmp = ao_mask[0].dot(aow.T) - vtmp+= _tau_dot(ao_mask, ao_mask, wv[4]) - #vmat[i][cupy.ix_(mask, mask)] += vtmp + vtmp+= _tau_dot(ao_mask, ao_mask, wv[i][4,p0:p1]) add_sparse(vmat[i], vtmp, idx) else: raise NotImplementedError(f'USE_SPARSITY = {USE_SPARSITY} is not implemented') @@ -575,17 +552,14 @@ def nr_rks(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, pass else: raise NotImplementedError(f'numint.nr_rks for functional {xc_code}') - t1 = log.timer_debug1('integration', *t1) p0 = p1 - - vmat = contract('pi,npq->niq', coeff, vmat) - vmat = contract('qj,niq->nij', coeff, vmat) - #rev_ao_idx = opt.rev_ao_idx - #vmat = take_last2d(vmat, rev_ao_idx) + t1 = log.timer_debug2('integration', *t1) + t0 = log.timer_debug1('vxc integration', *t0) + rev_ao_idx = opt.rev_ao_idx + vmat = take_last2d(vmat, rev_ao_idx) if xctype != 'LDA': - vmat = vmat + vmat.transpose([0,2,1]) - #transpose_sum(vmat) + transpose_sum(vmat) if FREE_CUPY_CACHE: dms = None @@ -719,27 +693,27 @@ def get_rho(ni, mol, dm, grids, max_memory=2000, verbose=None): log = logger.new_logger(mol, verbose) coeff = cupy.asarray(opt.coeff) nao = coeff.shape[0] - dm = coeff @ cupy.asarray(dm) @ coeff.T - mo_coeff = getattr(dm, 'mo_coeff', None) mo_occ = getattr(dm,'mo_occ', None) + dm = coeff @ cupy.asarray(dm) @ coeff.T if mo_coeff is not None: mo_coeff = coeff @ mo_coeff ngrids = grids.weights.size rho = cupy.empty(ngrids) p0 = p1 = 0 + t1 = t0 = log.init_timer() for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, 0): p1 = p0 + weight.size - t0 = log.init_timer() if mo_coeff is None: rho[p0:p1] = eval_rho(mol, ao_mask, dm[np.ix_(idx,idx)], xctype='LDA', hermi=1) else: mo_coeff_mask = mo_coeff[idx,:] rho[p0:p1] = eval_rho2(mol, ao_mask, mo_coeff_mask, mo_occ, None, 'LDA') p0 = p1 - log.timer_debug1('eval rho', *t0) + t1 = log.timer_debug2('eval rho slice', *t1) + t0 = log.timer_debug1('eval rho', *t0) if FREE_CUPY_CACHE: dm = None @@ -763,10 +737,9 @@ def nr_rks_fxc(ni, mol, grids, xc_code, dm0=None, dms=None, relativity=0, hermi= # AO basis -> gdftopt AO basis with_mocc = hasattr(dms, 'mo1') if with_mocc: - mo1 = contract('nio,pi->npo', dms.mo1, coeff) * 2.0**0.5 - occ_coeff = contract('io,pi->po', dms.occ_coeff, coeff) * 2.0**0.5 - dms = contract('nij,qj->niq', dms, coeff) - dms = contract('pi,niq->npq', coeff, dms) + mo1 = dms.mo1[:,opt.ao_idx] * 2.0**0.5 + occ_coeff = dms.occ_coeff[opt.ao_idx] * 2.0**0.5 + dms = take_last2d(dms, opt.ao_idx) nset = len(dms) vmat = cupy.zeros((nset, nao, nao)) @@ -776,8 +749,8 @@ def nr_rks_fxc(ni, mol, grids, xc_code, dm0=None, dms=None, relativity=0, hermi= ao_deriv = 1 p0 = 0 p1 = 0 + t1 = t0 = log.init_timer() for ao, mask, weights, coords in ni.block_loop(opt.mol, grids, nao, ao_deriv): - t0 = log.init_timer() p0, p1 = p1, p1+len(weights) # precompute molecular orbitals if with_mocc: @@ -788,7 +761,7 @@ def nr_rks_fxc(ni, mol, grids, xc_code, dm0=None, dms=None, relativity=0, hermi= c0 = contract('nig,io->nog', ao, occ_coeff_mask) else: # mgga c0 = contract('nig,io->nog', ao, occ_coeff_mask) - + t1 = log.timer_debug2(f'eval occ_coeff, with mocc: {with_mocc}', *t1) if with_mocc: rho1 = eval_rho4(opt.mol, ao, c0, mo1[:,mask], xctype=xctype, with_lapl=False) else: @@ -798,7 +771,7 @@ def nr_rks_fxc(ni, mol, grids, xc_code, dm0=None, dms=None, relativity=0, hermi= rho_tmp = eval_rho(opt.mol, ao, dms[i][np.ix_(mask,mask)], xctype=xctype, hermi=hermi, with_lapl=False) rho1.append(rho_tmp) rho1 = cupy.stack(rho1, axis=0) - t0 = log.timer_debug1('rho', *t0) + t1 = log.timer_debug2('eval rho', *t1) # precompute fxc_w if xctype == 'LDA': @@ -826,14 +799,14 @@ def nr_rks_fxc(ni, mol, grids, xc_code, dm0=None, dms=None, relativity=0, hermi= vmat_tmp+= _tau_dot(ao, ao, wv[i,4]) add_sparse(vmat[i], vmat_tmp, mask) - t0 = log.timer_debug1('vxc', *t0) + t1 = log.timer_debug2('integration', *t1) ao = c0 = rho1 = None + t0 = log.timer_debug1('vxc', *t0) - vmat = contract('pi,npq->niq', coeff, vmat) - vmat = contract('qj,niq->nij', coeff, vmat) + vmat = take_last2d(vmat, opt.rev_ao_idx) if xctype != 'LDA': - #transpose_sum(vmat) - vmat = vmat + vmat.transpose([0,2,1]) + transpose_sum(vmat) + if FREE_CUPY_CACHE: dms = None cupy.get_default_memory_pool().free_all_blocks() @@ -1067,13 +1040,14 @@ def nr_nlc_vxc(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, wv = vv_vxc[:,p0:p1] * weight wv[0] *= .5 aow = _scale_ao(ao, wv) - #vmat += ao[0].dot(aow.T) add_sparse(vmat, ao[0].dot(aow.T), mask) t1 = log.timer_debug1('integration', *t1) - vmat = vmat + vmat.T - vmat = contract('pi,pq->iq', coeff, vmat) - vmat = contract('qj,iq->ij', coeff, vmat) + transpose_sum(vmat) + vmat = take_last2d(vmat, opt.rev_ao_idx) + #vmat = vmat + vmat.T + #vmat = contract('pi,pq->iq', coeff, vmat) + #vmat = contract('qj,iq->ij', coeff, vmat) log.timer_debug1('eval vv10', *t0) return nelec, excsum, vmat @@ -1101,28 +1075,31 @@ def cache_xc_kernel(ni, mol, grids, xc_code, mo_coeff, mo_occ, spin=0, if spin == 0: mo_coeff = coeff @ mo_coeff rho = [] - t0 = log.init_timer() + t1 = t0 = log.init_timer() for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv): mo_coeff_mask = mo_coeff[idx,:] rho_slice = eval_rho2(mol, ao_mask, mo_coeff_mask, mo_occ, None, xctype) rho.append(rho_slice) - t0 = log.timer_debug1('eval rho in fxc', *t0) + t1 = log.timer_debug2('eval rho slice', *t1) rho = cupy.hstack(rho) + t0 = log.timer_debug1('eval rho in fxc', *t0) else: mo_coeff = contract('ip,npj->nij', coeff, cupy.asarray(mo_coeff)) rhoa = [] rhob = [] - t0 = log.init_timer() + t1 = t0 = log.init_timer() for ao_mask, idx, weight, _ in ni.block_loop(mol, grids, nao, ao_deriv): mo_coeff_mask = mo_coeff[:,idx,:] rhoa_slice = eval_rho2(mol, ao_mask, mo_coeff_mask[0], mo_occ[0], None, xctype) rhob_slice = eval_rho2(mol, ao_mask, mo_coeff_mask[1], mo_occ[1], None, xctype) rhoa.append(rhoa_slice) rhob.append(rhob_slice) - t0 = log.timer_debug1('eval rho in fxc', *t0) + t1 = log.timer_debug2('eval rho in fxc', *t1) #rho = (cupy.hstack(rhoa), cupy.hstack(rhob)) rho = cupy.stack([cupy.hstack(rhoa), cupy.hstack(rhob)], axis=0) + t0 = log.timer_debug1('eval rho in fxc', *t0) vxc, fxc = ni.eval_xc_eff(xc_code, rho, deriv=2, xctype=xctype)[1:3] + t0 = log.timer_debug1('eval fxc', *t0) return rho, vxc, fxc @cupy.fuse() @@ -1267,18 +1244,18 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000, mol = opt.mol with opt.gdft_envs_cache(): block_id = 0 + t1 = log.init_timer() for ip0, ip1 in lib.prange(0, ngrids, blksize): coords = grids.coords[ip0:ip1] weight = grids.weights[ip0:ip1] - + t1 = log.init_timer() # cache ao indices - if (deriv, block_id, blksize, ngrids) not in ni.non0ao_idx: + if (block_id, blksize, ngrids) not in ni.non0ao_idx: stream = cupy.cuda.get_current_stream() cutoff = AO_THRESHOLD ng = ip1 - ip0 ao_loc = mol.ao_loc_nr() nbas = mol.nbas - t0 = log.init_timer() non0shl_idx = cupy.zeros(len(ao_loc)-1, dtype=np.int32) libgdft.GDFTscreen_index( ctypes.cast(stream.ptr, ctypes.c_void_p), @@ -1317,24 +1294,25 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000, idx = cupy.hstack([idx, zero_idx[:pad]]) pad = min(pad, len(zero_idx)) non0shl_idx = cupy.asarray(np.where(non0shl_idx)[0], dtype=np.int32) - ni.non0ao_idx[deriv, block_id, blksize, ngrids] = (pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice) - log.timer_debug1('init ao sparsity', *t0) + ni.non0ao_idx[block_id, blksize, ngrids] = (pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice) + t1 = log.timer_debug2('init ao sparsity', *t1) else: - pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice = ni.non0ao_idx[deriv, block_id, blksize, ngrids] - t0 = log.init_timer() + pad, idx, non0shl_idx, ctr_offsets_slice, ao_loc_slice = ni.non0ao_idx[block_id, blksize, ngrids] + ao_mask = eval_ao( ni, mol, coords, deriv, nao_slice=len(idx), shls_slice=non0shl_idx, ao_loc_slice=ao_loc_slice, ctr_offsets_slice=ctr_offsets_slice) + + t1 = log.timer_debug2('evaluate ao slice', *t1) if pad > 0: if deriv == 0: ao_mask[-pad:,:] = 0.0 else: ao_mask[:,-pad:,:] = 0.0 block_id += 1 - log.timer_debug1('evaluate ao slice', *t0) yield ao_mask, idx, weight, coords class NumInt(numint.NumInt): @@ -1665,6 +1643,7 @@ def build(self, mol=None): pmol._decontracted = True self.mol = pmol inv_idx = np.argsort(ao_idx, kind='stable').astype(np.int32) + self.ao_idx = cupy.asarray(ao_idx, dtype=np.int32) self.rev_ao_idx = cupy.asarray(inv_idx, dtype=np.int32) self.coeff = coeff[ao_idx] self.l_ctr_offsets = np.append(0, np.cumsum(l_ctr_counts)).astype(np.int32) diff --git a/gpu4pyscf/dft/rks.py b/gpu4pyscf/dft/rks.py index 9af80ab2..34ae52dd 100644 --- a/gpu4pyscf/dft/rks.py +++ b/gpu4pyscf/dft/rks.py @@ -57,6 +57,7 @@ def prune_small_rho_grids_(ks, mol, dm, grids): grids.coords = cupy.vstack( [grids.coords, pad]) grids.weights = cupy.hstack([grids.weights, cupy.zeros(padding)]) + # make_mask has to be executed on cpu for now. #grids.non0tab = grids.make_mask(mol, grids.coords) #grids.screen_index = grids.non0tab diff --git a/gpu4pyscf/grad/rhf.py b/gpu4pyscf/grad/rhf.py index c96820b6..9f1f9efc 100644 --- a/gpu4pyscf/grad/rhf.py +++ b/gpu4pyscf/grad/rhf.py @@ -22,7 +22,7 @@ from pyscf.grad import rhf from gpu4pyscf.lib.cupy_helper import load_library from gpu4pyscf.scf.hf import _VHFOpt -from gpu4pyscf.lib.cupy_helper import tag_array, contract +from gpu4pyscf.lib.cupy_helper import tag_array, contract, take_last2d from gpu4pyscf.df import int3c2e #TODO: move int3c2e to out of df from gpu4pyscf.lib import logger @@ -552,6 +552,50 @@ def calculate_h1e(h1_gpu, s1_gpu): de -= cupy.sum(de, axis=0)/len(atmlst) return de.get() +def get_grad_hcore(mf_grad): + mf = mf_grad.base + mol = mf.mol + natm = mol.natm + nao = mol.nao + mo_occ = mf.mo_occ + mo_coeff = cupy.asarray(mf.mo_coeff) + orbo = mo_coeff[:,mo_occ>0] + nocc = orbo.shape[1] + + dh1e = cupy.zeros([3,natm,nao,nocc]) + coords = mol.atom_coords() + charges = cupy.asarray(mol.atom_charges(), dtype=np.float64) + fakemol = gto.fakemol_for_charges(coords) + intopt = int3c2e.VHFOpt(mol, fakemol, 'int2e') + intopt.build(1e-14, diag_block_with_triu=True, aosym=False, group_size=int3c2e.BLKSIZE, group_size_aux=int3c2e.BLKSIZE) + orbo_sorted = orbo[intopt.sph_ao_idx] + mo_coeff_sorted = mo_coeff[intopt.sph_ao_idx] + for i0,i1,j0,j1,k0,k1,int3c_blk in int3c2e.loop_int3c2e_general(intopt, ip_type='ip1'): + dh1e[:,k0:k1,j0:j1,:] += contract('xkji,io->xkjo', int3c_blk, orbo_sorted[i0:i1]) + dh1e[:,k0:k1,i0:i1,:] += contract('xkji,jo->xkio', int3c_blk, orbo_sorted[j0:j1]) + dh1e = contract('xkjo,k->xkjo', dh1e, -charges) + dh1e = contract('xkjo,jp->xkpo', dh1e, mo_coeff_sorted) + + h1 = mf_grad.get_hcore(mol) + aoslices = mol.aoslice_by_atom() + with_ecp = mol.has_ecp() + if with_ecp: + ecp_atoms = set(mol._ecpbas[:,gto.ATOM_OF]) + else: + ecp_atoms = () + for atm_id in range(natm): + shl0, shl1, p0, p1 = aoslices[atm_id] + h1ao = numpy.zeros([3,nao,nao]) + with mol.with_rinv_at_nucleus(atm_id): + if with_ecp and atm_id in ecp_atoms: + h1ao += mol.intor('ECPscalar_iprinv', comp=3) + h1ao[:,p0:p1] += h1[:,p0:p1] + h1ao += h1ao.transpose([0,2,1]) + h1ao = cupy.asarray(h1ao) + h1mo = contract('xij,jo->xio', h1ao, orbo) + dh1e[:,atm_id] += contract('xio,ip->xpo', h1mo, mo_coeff) + return dh1e#2.0 * cupy.einsum('kx,k->kx', dh1e, -charges) + class Gradients(rhf.Gradients): from gpu4pyscf.lib.utils import to_cpu, to_gpu, device diff --git a/gpu4pyscf/grad/rks.py b/gpu4pyscf/grad/rks.py index 4b0b5191..2bb22df6 100644 --- a/gpu4pyscf/grad/rks.py +++ b/gpu4pyscf/grad/rks.py @@ -27,7 +27,8 @@ from gpu4pyscf.grad import rhf as rhf_grad from gpu4pyscf.dft import numint, xc_deriv, rks from gpu4pyscf.dft.numint import _GDFTOpt, AO_THRESHOLD -from gpu4pyscf.lib.cupy_helper import contract, get_avail_mem, add_sparse, tag_array, load_library +from gpu4pyscf.lib.cupy_helper import ( + contract, get_avail_mem, add_sparse, tag_array, load_library, take_last2d) from gpu4pyscf.lib import logger from pyscf import __config__ @@ -121,10 +122,12 @@ def get_vxc(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, mo_coeff = cupy.asarray(dms.mo_coeff) coeff = cupy.asarray(opt.coeff) nao, nao0 = coeff.shape - dms = cupy.asarray(dms) - dms = [cupy.einsum('pi,ij,qj->pq', coeff, dm, coeff) - for dm in dms.reshape(-1,nao0,nao0)] - mo_coeff = coeff @ mo_coeff + dms = cupy.asarray(dms).reshape(-1,nao0,nao0) + dms = take_last2d(dms, opt.ao_idx) + #dms = [cupy.einsum('pi,ij,qj->pq', coeff, dm, coeff) + # for dm in dms.reshape(-1,nao0,nao0)] + #mo_coeff = coeff @ mo_coeff + mo_coeff = mo_coeff[opt.ao_idx] nset = len(dms) assert nset == 1 @@ -172,7 +175,8 @@ def get_vxc(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, vtmp = _gga_grad_sum_(ao_mask, wv) vtmp += _tau_grad_dot_(ao_mask, wv[4]) add_sparse(vmat[idm], vtmp, idx) - vmat = [cupy.einsum('pi,npq,qj->nij', coeff, v, coeff) for v in vmat] + #vmat = [cupy.einsum('pi,npq,qj->nij', coeff, v, coeff) for v in vmat] + vmat = take_last2d(vmat, opt.rev_ao_idx) exc = None if nset == 1: vmat = vmat[0] @@ -229,8 +233,10 @@ def get_nlc_vxc(ni, mol, grids, xc_code, dms, relativity=0, hermi=1, vmat_tmp = _gga_grad_sum_(ao_mask, wv) add_sparse(vmat, vmat_tmp, mask) - vmat = contract('npq,qj->npj', vmat, coeff) - vmat = contract('pi,npj->nij', coeff, vmat) + #vmat = contract('npq,qj->npj', vmat, coeff) + #vmat = contract('pi,npj->nij', coeff, vmat) + rev_ao_idx = opt.rev_ao_idx + vmat = take_last2d(vmat, rev_ao_idx) exc = None # - sign because nabla_X = -nabla_x return exc, -vmat @@ -274,10 +280,12 @@ def _make_dR_dao_w(ao, wv): def _d1_dot_(ao1, ao2, out=None): if out is None: - vmat0 = cupy.dot(ao1[0], ao2) - vmat1 = cupy.dot(ao1[1], ao2) - vmat2 = cupy.dot(ao1[2], ao2) - return cupy.stack([vmat0,vmat1,vmat2]) + out = cupy.empty([3, ao1[0].shape[0], ao2.shape[1]]) + out[0] = cupy.dot(ao1[0], ao2) + out[1] = cupy.dot(ao1[1], ao2) + out[2] = cupy.dot(ao1[2], ao2) + return out + #return cupy.stack([vmat0,vmat1,vmat2]) else: cupy.dot(ao1[0], ao2, out=out[0]) cupy.dot(ao1[1], ao2, out=out[1]) diff --git a/gpu4pyscf/hessian/rhf.py b/gpu4pyscf/hessian/rhf.py index b6421424..5efb8f12 100644 --- a/gpu4pyscf/hessian/rhf.py +++ b/gpu4pyscf/hessian/rhf.py @@ -57,6 +57,7 @@ def hess_elec(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None, mo_coeff = cupy.asarray(mo_coeff) de2 = hessobj.partial_hess_elec(mo_energy, mo_coeff, mo_occ, atmlst, max_memory, log) + t1 = log.timer_debug1('hess elec', *t1) if h1ao is None: h1ao = hessobj.make_h1(mo_coeff, mo_occ, hessobj.chkfile, atmlst, log) t1 = log.timer_debug1('making H1', *t1) @@ -503,11 +504,13 @@ def hcore_generator(self, mol=None): else: ecp_atoms = () aoslices = mol.aoslice_by_atom() - nbas = mol.nbas nao = mol.nao_nr() h1aa, h1ab = self.get_hcore(mol) h1aa = cupy.asarray(h1aa) h1ab = cupy.asarray(h1ab) + + rinv2aa_all = {} + rinv2ab_all = {} def get_hcore(iatm, jatm): ish0, ish1, i0, i1 = aoslices[iatm] jsh0, jsh1, j0, j1 = aoslices[jatm] @@ -515,15 +518,20 @@ def get_hcore(iatm, jatm): zj = mol.atom_charge(jatm) if iatm == jatm: with mol.with_rinv_at_nucleus(iatm): - rinv2aa = mol.intor('int1e_ipiprinv', comp=9) - rinv2ab = mol.intor('int1e_iprinvip', comp=9) - rinv2aa *= zi - rinv2ab *= zi - if with_ecp and iatm in ecp_atoms: - rinv2aa -= mol.intor('ECPscalar_ipiprinv', comp=9) - rinv2ab -= mol.intor('ECPscalar_iprinvip', comp=9) - rinv2aa = cupy.asarray(rinv2aa) - rinv2ab = cupy.asarray(rinv2ab) + if iatm not in rinv2aa_all: + rinv2aa = zi * mol.intor('int1e_ipiprinv', comp=9) + if with_ecp and iatm in ecp_atoms: + rinv2aa -= mol.intor('ECPscalar_ipiprinv', comp=9) + rinv2aa_all[iatm] = rinv2aa + rinv2aa = cupy.asarray(rinv2aa_all[iatm]) + + if iatm not in rinv2ab_all: + rinv2ab = zi * mol.intor('int1e_iprinvip', comp=9) + if with_ecp and iatm in ecp_atoms: + rinv2ab -= mol.intor('ECPscalar_iprinvip', comp=9) + rinv2ab_all[iatm] = rinv2ab + rinv2ab = cupy.asarray(rinv2ab_all[iatm]) + rinv2aa = rinv2aa.reshape(3,3,nao,nao) rinv2ab = rinv2ab.reshape(3,3,nao,nao) hcore = -rinv2aa - rinv2ab @@ -538,28 +546,42 @@ def get_hcore(iatm, jatm): hcore = cupy.zeros((3,3,nao,nao)) hcore[:,:,i0:i1,j0:j1] += h1ab[:,:,i0:i1,j0:j1] with mol.with_rinv_at_nucleus(iatm): - shls_slice = (jsh0, jsh1, 0, nbas) - rinv2aa = mol.intor('int1e_ipiprinv', comp=9, shls_slice=shls_slice) - rinv2ab = mol.intor('int1e_iprinvip', comp=9, shls_slice=shls_slice) - rinv2aa *= zi - rinv2ab *= zi - if with_ecp and iatm in ecp_atoms: - rinv2aa -= mol.intor('ECPscalar_ipiprinv', comp=9, shls_slice=shls_slice) - rinv2ab -= mol.intor('ECPscalar_iprinvip', comp=9, shls_slice=shls_slice) - rinv2aa = cupy.asarray(rinv2aa) - rinv2ab = cupy.asarray(rinv2ab) + #rinv2aa = mol.intor('int1e_ipiprinv', comp=9, shls_slice=shls_slice) + if iatm not in rinv2aa_all: + rinv2aa = zi * mol.intor('int1e_ipiprinv', comp=9) + if with_ecp and iatm in ecp_atoms: + rinv2aa -= mol.intor('ECPscalar_ipiprinv', comp=9) + rinv2aa_all[iatm] = rinv2aa + rinv2aa = cupy.asarray(rinv2aa_all[iatm][:,j0:j1]) + + #rinv2ab = mol.intor('int1e_iprinvip', comp=9, shls_slice=shls_slice) + if iatm not in rinv2ab_all: + rinv2ab = zi * mol.intor('int1e_iprinvip', comp=9) + if with_ecp and iatm in ecp_atoms: + rinv2ab -= mol.intor('ECPscalar_iprinvip', comp=9) + rinv2ab_all[iatm] = rinv2ab + rinv2ab = cupy.asarray(rinv2ab_all[iatm][:,j0:j1]) + hcore[:,:,j0:j1] += rinv2aa.reshape(3,3,j1-j0,nao) hcore[:,:,j0:j1] += rinv2ab.reshape(3,3,j1-j0,nao).transpose(1,0,2,3) with mol.with_rinv_at_nucleus(jatm): - shls_slice = (ish0, ish1, 0, nbas) - rinv2aa = mol.intor('int1e_ipiprinv', comp=9, shls_slice=shls_slice) - rinv2ab = mol.intor('int1e_iprinvip', comp=9, shls_slice=shls_slice) - rinv2aa *= zj - rinv2ab *= zj - if with_ecp and jatm in ecp_atoms: - rinv2aa -= mol.intor('ECPscalar_ipiprinv', comp=9, shls_slice=shls_slice) - rinv2ab -= mol.intor('ECPscalar_iprinvip', comp=9, shls_slice=shls_slice) + # rinv2aa = zj * mol.intor('int1e_ipiprinv', comp=9, shls_slice=shls_slice) + if jatm not in rinv2aa_all: + rinv2aa = zj * mol.intor('int1e_ipiprinv', comp=9) + if with_ecp and jatm in ecp_atoms: + rinv2aa -= mol.intor('ECPscalar_ipiprinv', comp=9) + rinv2aa_all[jatm] = rinv2aa + rinv2aa = cupy.asarray(rinv2aa_all[jatm][:,i0:i1]) + + # rinv2ab = mol.intor('int1e_iprinvip', comp=9, shls_slice=shls_slice) + if jatm not in rinv2ab_all: + rinv2ab = zj * mol.intor('int1e_iprinvip', comp=9) + if with_ecp and jatm in ecp_atoms: + rinv2ab -= mol.intor('ECPscalar_iprinvip', comp=9) + rinv2ab_all[jatm] = rinv2ab + rinv2ab = cupy.asarray(rinv2ab_all[jatm][:,i0:i1]) + rinv2aa = cupy.asarray(rinv2aa) rinv2ab = cupy.asarray(rinv2ab) hcore[:,:,i0:i1] += rinv2aa.reshape(3,3,i1-i0,nao) diff --git a/gpu4pyscf/hessian/rks.py b/gpu4pyscf/hessian/rks.py index 43f97c7b..f808ecde 100644 --- a/gpu4pyscf/hessian/rks.py +++ b/gpu4pyscf/hessian/rks.py @@ -27,7 +27,7 @@ from gpu4pyscf.hessian import rhf as rhf_hess from gpu4pyscf.grad import rks as rks_grad from gpu4pyscf.dft import numint -from gpu4pyscf.lib.cupy_helper import contract, add_sparse +from gpu4pyscf.lib.cupy_helper import contract, add_sparse, take_last2d from gpu4pyscf.lib import logger # import pyscf.grad.rks to activate nuc_grad_method method @@ -354,12 +354,13 @@ def _d1d2_dot_(vmat, mol, ao1, ao2, mask, ao_loc, dR1_on_bra=True): for d2 in range(3): vmat[d1,d2] += numint._dot_ao_ao(mol, ao1[d1], ao2[d2], mask, shls_slice, ao_loc) + #vmat += contract('xig,yjg->xyij', ao1, ao2) else: # (d/dR2 bra) * (d/dR1 ket) for d1 in range(3): for d2 in range(3): vmat[d1,d2] += numint._dot_ao_ao(mol, ao1[d2], ao2[d1], mask, shls_slice, ao_loc) - + #vmat += contract('yig,xjg->xyij', ao1, ao2) def _get_vxc_deriv2(hessobj, mo_coeff, mo_occ, max_memory): mol = hessobj.mol mf = hessobj.base @@ -393,15 +394,15 @@ def _get_vxc_deriv2(hessobj, mo_coeff, mo_occ, max_memory): ipip = cupy.zeros((3,3,nao,nao)) if xctype == 'LDA': ao_deriv = 1 + t1 = log.init_timer() for ao_mask, mask, weight, coords \ in ni.block_loop(opt.mol, grids, nao, ao_deriv, max_memory): - t0 = log.init_timer() nao_non0 = len(mask) ao = contract('nip,ij->njp', ao_mask, coeff[mask]) rho = numint.eval_rho2(opt.mol, ao[0], mo_coeff, mo_occ, mask, xctype) - t0 = log.timer_debug1('eval rho', *t0) + t1 = log.timer_debug2('eval rho', *t1) vxc, fxc = ni.eval_xc_eff(mf.xc, rho, 2, xctype=xctype)[1:3] - t0 = log.timer_debug1('eval vxc', *t0) + t1 = log.timer_debug2('eval vxc', *t1) wv = weight * vxc[0] aow = [numint._scale_ao(ao[i], wv) for i in range(1, 4)] _d1d2_dot_(ipip, mol, aow, ao[1:4], mask, ao_loc, False) @@ -417,27 +418,28 @@ def _get_vxc_deriv2(hessobj, mo_coeff, mo_occ, max_memory): vmat_tmp = cupy.zeros([3,3,nao_non0,nao_non0]) aow = [numint._scale_ao(ao_mask[0], wv[i]) for i in range(3)] _d1d2_dot_(vmat_tmp, mol, ao_mask[1:4], aow, mask, ao_loc, False) - vmat_tmp = contract('pi,xypq->xyiq', coeff[mask], vmat_tmp) - vmat_tmp = contract('qj,xyiq->xyij', coeff[mask], vmat_tmp) - vmat[ia] += vmat_tmp + #vmat_tmp = contract('pi,xypq->xyiq', coeff[mask], vmat_tmp) + #vmat_tmp = contract('qj,xyiq->xyij', coeff[mask], vmat_tmp) + #vmat[ia] += vmat_tmp + add_sparse(vmat[ia], vmat_tmp, mask) ao_dm0 = aow = None - t0 = log.timer_debug1('integration', *t0) + t1 = log.timer_debug2('integration', *t1) for ia in range(mol.natm): + vmat[ia] = take_last2d(vmat[ia], opt.rev_ao_idx) p0, p1 = aoslices[ia][2:] vmat[ia,:,:,:,p0:p1] += ipip[:,:,:,p0:p1] elif xctype == 'GGA': ao_deriv = 2 - comp = (ao_deriv+1)*(ao_deriv+2)*(ao_deriv+3)//6 + t1 = log.init_timer() for ao_mask, mask, weight, coords \ - in ni.block_loop(opt.mol, grids, nao, ao_deriv, max_memory, extra=5*comp*nao): - t0 = log.init_timer() + in ni.block_loop(opt.mol, grids, nao, ao_deriv, max_memory): nao_non0 = len(mask) ao = contract('nip,ij->njp', ao_mask, coeff[mask]) rho = numint.eval_rho2(opt.mol, ao[:4], mo_coeff, mo_occ, mask, xctype) - t0 = log.timer_debug1('eval rho', *t0) + t1 = log.timer_debug2('eval rho', *t1) vxc, fxc = ni.eval_xc_eff(mf.xc, rho, 2, xctype=xctype)[1:3] - t0 = log.timer_debug1('eval vxc', *t0) + t1 = log.timer_debug2('eval vxc', *t1) wv = weight * vxc wv[0] *= .5 aow = rks_grad._make_dR_dao_w(ao, wv) @@ -460,14 +462,19 @@ def _get_vxc_deriv2(hessobj, mo_coeff, mo_occ, max_memory): for i in range(3): aow = rks_grad._make_dR_dao_w(ao_mask, wv[i]) rks_grad._d1_dot_(aow, ao_mask[0].T, out=vmat_tmp[i]) - aow = [numint._scale_ao(ao_mask[:4], wv[i,:4]) for i in range(3)] + ng = len(weight) + aow = cupy.empty([3,nao_non0,ng]) + for i in range(3): + aow[i] = numint._scale_ao(ao_mask[:4], wv[i,:4]) _d1d2_dot_(vmat_tmp, mol, ao_mask[1:4], aow, mask, ao_loc, False) - vmat_tmp = contract('pi,xypq->xyiq', coeff[mask], vmat_tmp) - vmat_tmp = contract('qj,xyiq->xyij', coeff[mask], vmat_tmp) - vmat[ia] += vmat_tmp + #vmat_tmp = contract('pi,xypq->xyiq', coeff[mask], vmat_tmp) + #vmat_tmp = contract('qj,xyiq->xyij', coeff[mask], vmat_tmp) + #vmat[ia] += vmat_tmp + add_sparse(vmat[ia], vmat_tmp, mask) ao_dm0 = aow = None - t0 = log.timer_debug1('integration', *t0) + t1 = log.timer_debug2('integration', *t1) for ia in range(mol.natm): + vmat[ia] = take_last2d(vmat[ia], opt.rev_ao_idx) p0, p1 = aoslices[ia][2:] vmat[ia,:,:,:,p0:p1] += ipip[:,:,:,p0:p1] vmat[ia,:,:,:,p0:p1] += ipip[:,:,p0:p1].transpose(1,0,3,2) @@ -477,15 +484,15 @@ def _get_vxc_deriv2(hessobj, mo_coeff, mo_occ, max_memory): YX, YY, YZ = 5, 7, 8 ZX, ZY, ZZ = 6, 8, 9 ao_deriv = 2 + t1 = log.init_timer() for ao_mask, mask, weight, coords \ in ni.block_loop(opt.mol, grids, nao, ao_deriv, max_memory): - t0 = log.init_timer() nao_non0 = len(mask) ao = contract('nip,ij->njp', ao_mask, coeff[mask]) rho = numint.eval_rho2(opt.mol, ao[:10], mo_coeff, mo_occ, mask, xctype) - t0 = log.timer_debug1('eval rho', *t0) + t1 = log.timer_debug2('eval rho', *t1) vxc, fxc = ni.eval_xc_eff(mf.xc, rho, 2, xctype=xctype)[1:3] - t0 = log.timer_debug1('eval vxc', *t0) + t1 = log.timer_debug2('eval vxc', *t1) wv = weight * vxc wv[0] *= .5 wv[4] *= .25 @@ -522,11 +529,13 @@ def _get_vxc_deriv2(hessobj, mo_coeff, mo_occ, max_memory): _d1d2_dot_(vmat_tmp, mol, [ao_mask[YX], ao_mask[YY], ao_mask[YZ]], aow, mask, ao_loc, False) aow = [numint._scale_ao(ao_mask[3], wv[i,4]) for i in range(3)] _d1d2_dot_(vmat_tmp, mol, [ao_mask[ZX], ao_mask[ZY], ao_mask[ZZ]], aow, mask, ao_loc, False) - vmat_tmp = contract('pi,xypq->xyiq', coeff[mask], vmat_tmp) - vmat_tmp = contract('qj,xyiq->xyij', coeff[mask], vmat_tmp) - vmat[ia] += vmat_tmp - t0 = log.timer_debug1('integration', *t0) + #vmat_tmp = contract('pi,xypq->xyiq', coeff[mask], vmat_tmp) + #vmat_tmp = contract('qj,xyiq->xyij', coeff[mask], vmat_tmp) + #vmat[ia] += vmat_tmp + add_sparse(vmat[ia], vmat_tmp, mask) + t1 = log.timer_debug2('integration', *t1) for ia in range(mol.natm): + vmat[ia] = take_last2d(vmat[ia], opt.rev_ao_idx) p0, p1 = aoslices[ia][2:] vmat[ia,:,:,:,p0:p1] += ipip[:,:,:,p0:p1] vmat[ia,:,:,:,p0:p1] += ipip[:,:,p0:p1].transpose(1,0,3,2) @@ -570,14 +579,14 @@ def _get_vxc_deriv1(hessobj, mo_coeff, mo_occ, max_memory): max_memory = max(2000, max_memory-vmat.size*8/1e6) if xctype == 'LDA': ao_deriv = 1 + t1 = t0 = log.init_timer() for ao, mask, weight, coords \ in ni.block_loop(opt.mol, grids, nao, ao_deriv, max_memory): - t0 = log.init_timer() ao = contract('nip,ij->njp', ao, coeff[mask]) rho = numint.eval_rho2(opt.mol, ao[0], mo_coeff, mo_occ, mask, xctype) - t0 = log.timer_debug1('eval rho', *t0) + t1 = log.timer_debug2('eval rho', *t1) vxc, fxc = ni.eval_xc_eff(mf.xc, rho, 2, xctype=xctype)[1:3] - t0 = log.timer_debug1('eval vxc', *t0) + t1 = log.timer_debug2('eval vxc', *t1) wv = weight * vxc[0] aow = numint._scale_ao(ao[0], wv) v_ip += rks_grad._d1_dot_(ao[1:4], aow.T) @@ -592,17 +601,17 @@ def _get_vxc_deriv1(hessobj, mo_coeff, mo_occ, max_memory): aow = [numint._scale_ao(ao[0], wv[i]) for i in range(3)] vmat[ia] += rks_grad._d1_dot_(aow, ao[0].T) ao_dm0 = aow = None - t0 = log.timer_debug1('integration', *t0) + t1 = log.timer_debug2('integration', *t1) elif xctype == 'GGA': ao_deriv = 2 + t1 = t0 = log.init_timer() for ao, mask, weight, coords \ in ni.block_loop(mol, grids, nao, ao_deriv, max_memory): - t0 = log.init_timer() ao = contract('nip,ij->njp', ao, coeff[mask]) rho = numint.eval_rho2(mol, ao[:4], mo_coeff, mo_occ, mask, xctype) - t0 = log.timer_debug1('eval rho', *t0) + t1 = log.timer_debug2('eval rho', *t1) vxc, fxc = ni.eval_xc_eff(mf.xc, rho, 2, xctype=xctype)[1:3] - t0 = log.timer_debug1('eval vxc', *t0) + t1 = log.timer_debug2('eval vxc', *t1) wv = weight * vxc wv[0] *= .5 v_ip += rks_grad._gga_grad_sum_(ao, wv) @@ -616,20 +625,20 @@ def _get_vxc_deriv1(hessobj, mo_coeff, mo_occ, max_memory): wv[:,0] *= .5 aow = [numint._scale_ao(ao[:4], wv[i,:4]) for i in range(3)] vmat[ia] += rks_grad._d1_dot_(aow, ao[0].T) - t0 = log.timer_debug1('integration', *t0) + t1 = log.timer_debug2('integration', *t1) ao_dm0 = aow = None elif xctype == 'MGGA': if grids.level < 5: log.warn('MGGA Hessian is sensitive to dft grids.') ao_deriv = 2 + t1 = t0 = log.init_timer() for ao, mask, weight, coords \ in ni.block_loop(opt.mol, grids, nao, ao_deriv, max_memory): - t0 = log.init_timer() ao = contract('nip,ij->njp', ao, coeff[mask]) rho = numint.eval_rho2(opt.mol, ao[:10], mo_coeff, mo_occ, mask, xctype) - t0 = log.timer_debug1('eval rho', *t0) + t1 = log.timer_debug2('eval rho', *t1) vxc, fxc = ni.eval_xc_eff(mf.xc, rho, 2, xctype=xctype)[1:3] - t0 = log.timer_debug1('eval vxc', *t0) + t1 = log.timer_debug2('eval vxc', *t0) wv = weight * vxc wv[0] *= .5 wv[4] *= .5 # for the factor 1/2 in tau @@ -649,7 +658,7 @@ def _get_vxc_deriv1(hessobj, mo_coeff, mo_occ, max_memory): aow = [numint._scale_ao(ao[j], wv[i,4]) for i in range(3)] vmat[ia] += rks_grad._d1_dot_(aow, ao[j].T) ao_dm0 = aow = None - t0 = log.timer_debug1('integration', *t0) + t1 = log.timer_debug2('integration', *t1) for ia in range(mol.natm): p0, p1 = aoslices[ia][2:] vmat[ia,:,p0:p1] += v_ip[:,p0:p1] diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py index bbaf2ebc..6a82e42d 100644 --- a/gpu4pyscf/lib/cupy_helper.py +++ b/gpu4pyscf/lib/cupy_helper.py @@ -165,9 +165,27 @@ def add_sparse(a, b, indices): ctypes.c_int(count) ) if err != 0: - raise RecursionError('failed in sparse_add2d') + raise RuntimeError('failed in sparse_add2d') return a +def dist_matrix(coords, out=None): + assert coords.flags.c_contiguous + n = coords.shape[0] + if out is None: + out = cupy.empty([n,n]) + + stream = cupy.cuda.get_current_stream() + err = libcupy_helper.dist_matrix( + ctypes.cast(stream.ptr, ctypes.c_void_p), + ctypes.cast(out.data.ptr, ctypes.c_void_p), + ctypes.cast(coords.data.ptr, ctypes.c_void_p), + ctypes.cast(coords.data.ptr, ctypes.c_void_p), + ctypes.c_int(n), + ) + if err != 0: + raise RuntimeError('failed in calculating distance matrix') + return out + def block_c2s_diag(ncart, nsph, angular, counts): ''' constract a cartesian to spherical transformation of n shells @@ -175,21 +193,18 @@ def block_c2s_diag(ncart, nsph, angular, counts): nshells = np.sum(counts) cart2sph = cupy.zeros([ncart, nsph]) - rows = [0] - cols = [0] + + rows = [np.array([0], dtype='int32')] + cols = [np.array([0], dtype='int32')] offsets = [] for l, count in zip(angular, counts): - for _ in range(count): - r, c = c2s_l[l].shape - rows.append(rows[-1] + r) - cols.append(cols[-1] + c) - offsets.append(c2s_offset[l]) - rows = np.asarray(rows, dtype='int32') - cols = np.asarray(cols, dtype='int32') - offsets = np.asarray(offsets, dtype='int32') + r, c = c2s_l[l].shape + rows.append(rows[-1][-1] + np.arange(1,count+1, dtype='int32') * r) + cols.append(cols[-1][-1] + np.arange(1,count+1, dtype='int32') * c) + offsets += [c2s_offset[l]] * count - rows = cupy.asarray(rows, dtype='int32') - cols = cupy.asarray(cols, dtype='int32') + rows = cupy.hstack(rows) + cols = cupy.hstack(cols) offsets = cupy.asarray(offsets, dtype='int32') stream = cupy.cuda.get_current_stream() @@ -300,8 +315,10 @@ def transpose_sum(a, stream=None): return a + a.transpose(0,2,1) ''' assert a.flags.c_contiguous - assert a.ndim == 3 n = a.shape[-1] + if a.ndim == 2: + a = a.reshape([-1,n,n]) + assert a.ndim == 3 count = a.shape[0] stream = cupy.cuda.get_current_stream() err = libcupy_helper.transpose_sum( diff --git a/gpu4pyscf/lib/cupy_helper/CMakeLists.txt b/gpu4pyscf/lib/cupy_helper/CMakeLists.txt index 3f59109e..445a10f6 100644 --- a/gpu4pyscf/lib/cupy_helper/CMakeLists.txt +++ b/gpu4pyscf/lib/cupy_helper/CMakeLists.txt @@ -17,13 +17,14 @@ #set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=sm_80") -add_library(cupy_helper SHARED +add_library(cupy_helper SHARED transpose.cu block_diag.cu contract_cderi_k.cu take_last2d.cu async_d2h_2d.cu add_sparse.cu + dist_matrix.cu ) set_target_properties(cupy_helper PROPERTIES diff --git a/gpu4pyscf/lib/cupy_helper/dist_matrix.cu b/gpu4pyscf/lib/cupy_helper/dist_matrix.cu new file mode 100644 index 00000000..77f78de7 --- /dev/null +++ b/gpu4pyscf/lib/cupy_helper/dist_matrix.cu @@ -0,0 +1,49 @@ +/* Copyright 2023 The GPU4PySCF Authors. All Rights Reserved. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include +#include +#define THREADS 32 + +__global__ +static void _calc_distances(double *dist, const double *x, const double *y, int n) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + if (i >= n || j >= n){ + return; + } + + double dx = x[3*i] - y[3*j]; + double dy = x[3*i+1] - y[3*j+1]; + double dz = x[3*i+2] - y[3*j+2]; + dist[i*n+j] = norm3d(dx, dy, dz); +} + +extern "C" { +int dist_matrix(cudaStream_t stream, double *dist, const double *x, const double *y, int n) +{ + int ntile = (n + THREADS - 1) / THREADS; + dim3 threads(THREADS, THREADS); + dim3 blocks(ntile, ntile); + _calc_distances<<>>(dist, x, y, n); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + return 1; + } + return 0; +} +} diff --git a/gpu4pyscf/lib/gdft/gen_grids.cu b/gpu4pyscf/lib/gdft/gen_grids.cu index d71eedf9..67cb869e 100644 --- a/gpu4pyscf/lib/gdft/gen_grids.cu +++ b/gpu4pyscf/lib/gdft/gen_grids.cu @@ -67,7 +67,7 @@ int ngrids, int natm) dx = xi - xj_t; dy = yi - yj_t; dz = zi - zj_t; - double dij = norm3d(dx, dy, dz); + double dij = rnorm3d(dx, dy, dz); // distance between atom i and atom j dij_smem[tx] = dij; @@ -88,7 +88,7 @@ int ngrids, int natm) double dij = dij_smem[l]; double aij = a_smem[l]; - double g = (atom_i == atom_j) ? 0.0 : (dig - djg) / dij; + double g = (atom_i == atom_j) ? 0.0 : (dig - djg) * dij; // atomic radii adjust function double g1 = g*g - 1.0; diff --git a/gpu4pyscf/lib/gdft/nr_eval_gto.cu b/gpu4pyscf/lib/gdft/nr_eval_gto.cu index 1e2689d7..2f2139a3 100644 --- a/gpu4pyscf/lib/gdft/nr_eval_gto.cu +++ b/gpu4pyscf/lib/gdft/nr_eval_gto.cu @@ -28,7 +28,7 @@ #include "nr_eval_gto.cuh" #include "contract_rho.cuh" -#define NG_PER_BLOCK 128 +#define NG_PER_BLOCK 256 #define LMAX 8 #define GTO_MAX_CART 15 @@ -354,9 +354,10 @@ static void _cart_kernel_deriv1(BasOffsets offsets) double ce_2a = 0; for (int ip = 0; ip < offsets.nprim; ++ip) { double c = coeffs[ip]; - double e = exp(-exps[ip] * rr); + double exp_ip = exps[ip]; + double e = exp(-exp_ip * rr); ce += c * e; - ce_2a += c * e * exps[ip]; + ce_2a += c * e * exp_ip; } ce *= offsets.fac; ce_2a *= -2 * offsets.fac; @@ -1025,9 +1026,10 @@ static void _sph_kernel_deriv1(BasOffsets offsets) double ce_2a = 0; for (int ip = 0; ip < offsets.nprim; ++ip) { double c = coeffs[ip]; - double e = exp(-exps[ip] * rr); + double exp_ip = exps[ip]; + double e = exp(-exp_ip * rr); ce += c * e; - ce_2a += c * e * exps[ip]; + ce_2a += c * e * exp_ip; } ce *= offsets.fac; ce_2a *= -2 * offsets.fac; diff --git a/gpu4pyscf/lib/gint/g3c2e.cu b/gpu4pyscf/lib/gint/g3c2e.cu index 2395fbbe..c0f0f425 100644 --- a/gpu4pyscf/lib/gint/g3c2e.cu +++ b/gpu4pyscf/lib/gint/g3c2e.cu @@ -5,7 +5,7 @@ void GINTfill_int3c2e_kernel(GINTEnvVars envs, ERITensor eri, BasisProdOffsets o int ntasks_kl = offsets.ntasks_kl; int task_ij = blockIdx.x * blockDim.x + threadIdx.x; int task_kl = blockIdx.y * blockDim.y + threadIdx.y; - + if (task_ij >= ntasks_ij || task_kl >= ntasks_kl) { return; } @@ -26,7 +26,7 @@ void GINTfill_int3c2e_kernel(GINTEnvVars envs, ERITensor eri, BasisProdOffsets o int lsh = bas_pair2ket[bas_kl]; double uw[NROOTS*2]; double g[GSIZE_INT3C]; - + double* __restrict__ a12 = c_bpcache.a12; double* __restrict__ x12 = c_bpcache.x12; double* __restrict__ y12 = c_bpcache.y12; @@ -48,9 +48,9 @@ void GINTfill_int3c2e_kernel(GINTEnvVars envs, ERITensor eri, BasisProdOffsets o as_ksh = lsh; as_lsh = ksh; } + GINTmemset_int3c2e(envs, eri, ish, jsh, ksh); for (ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) { for (kl = prim_kl; kl < prim_kl+nprim_kl; ++kl) { - double aij = a12[ij]; double xij = x12[ij]; double yij = y12[ij]; @@ -65,7 +65,7 @@ void GINTfill_int3c2e_kernel(GINTEnvVars envs, ERITensor eri, BasisProdOffsets o double aijkl = aij + akl; double a1 = aij * akl; double a0 = a1 / aijkl; - double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0; + double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0; a0 *= theta; double x = a0 * (xijxkl * xijxkl + yijykl * yijykl + zijzkl * zijzkl); GINTrys_root(x, uw); diff --git a/gpu4pyscf/lib/gint/gout3c2e.cu b/gpu4pyscf/lib/gint/gout3c2e.cu index fb4ff75e..c6fb1eba 100644 --- a/gpu4pyscf/lib/gint/gout3c2e.cu +++ b/gpu4pyscf/lib/gint/gout3c2e.cu @@ -235,15 +235,18 @@ static void GINTwrite_int3c2e_ipip_direct(GINTEnvVars envs, ERITensor eri, doubl double eri_zy = 0; double eri_zz = 0; for (int ir = 0; ir < NROOTS; ++ir){ - eri_xx += g3[ix + ir] * g0[iy + ir] * g0[iz + ir]; - eri_xy += g2[ix + ir] * g1[iy + ir] * g0[iz + ir]; - eri_xz += g2[ix + ir] * g0[iy + ir] * g1[iz + ir]; - eri_yx += g1[ix + ir] * g2[iy + ir] * g0[iz + ir]; - eri_yy += g0[ix + ir] * g3[iy + ir] * g0[iz + ir]; - eri_yz += g0[ix + ir] * g2[iy + ir] * g1[iz + ir]; - eri_zx += g1[ix + ir] * g0[iy + ir] * g2[iz + ir]; - eri_zy += g0[ix + ir] * g1[iy + ir] * g2[iz + ir]; - eri_zz += g0[ix + ir] * g0[iy + ir] * g3[iz + ir]; + double g0_x = g0[ix + ir]; + double g0_y = g0[iy + ir]; + double g0_z = g0[iz + ir]; + eri_xx += g3[ix + ir] * g0_y * g0_z ; + eri_xy += g2[ix + ir] * g1[iy + ir] * g0_z ; + eri_xz += g2[ix + ir] * g0_y * g1[iz + ir]; + eri_yx += g1[ix + ir] * g2[iy + ir] * g0_z ; + eri_yy += g0_x * g3[iy + ir] * g0_z ; + eri_yz += g0_x * g2[iy + ir] * g1[iz + ir]; + eri_zx += g1[ix + ir] * g0_y * g2[iz + ir]; + eri_zy += g0_x * g1[iy + ir] * g2[iz + ir]; + eri_zz += g0_x * g0_y * g3[iz + ir]; } off = i+jstride*j; pxx_eri[off] += eri_xx; @@ -314,13 +317,38 @@ static void GINTwrite_int3c2e_ip_direct(GINTEnvVars envs, ERITensor eri, double* } } +template __device__ +static void GINTmemset_int3c2e(GINTEnvVars envs, ERITensor eri, int ish, int jsh, int ksh) +{ + int *ao_loc = c_bpcache.ao_loc; + size_t jstride = eri.stride_j; + size_t kstride = eri.stride_k; + int i0 = ao_loc[ish ] - eri.ao_offsets_i; + int i1 = ao_loc[ish+1] - eri.ao_offsets_i; + int j0 = ao_loc[jsh ] - eri.ao_offsets_j; + int j1 = ao_loc[jsh+1] - eri.ao_offsets_j; + int k0 = ao_loc[ksh ] - eri.ao_offsets_k; + int k1 = ao_loc[ksh+1] - eri.ao_offsets_k; + int i, j, k, n; + double* __restrict__ p_eri; + + for (n = 0, k = k0; k < k1; ++k) { + p_eri = eri.data + k * kstride; + + for (j = j0; j < j1; ++j) { + for (i = i0; i < i1; ++i, ++n) { + p_eri[i+jstride*j] = 0; + } + } + } +} + template __device__ static void GINTwrite_int3c2e_direct(GINTEnvVars envs, ERITensor eri, double* g, int ish, int jsh, int ksh) { int *ao_loc = c_bpcache.ao_loc; size_t jstride = eri.stride_j; size_t kstride = eri.stride_k; - size_t lstride = eri.stride_l; int i0 = ao_loc[ish ] - eri.ao_offsets_i; int i1 = ao_loc[ish+1] - eri.ao_offsets_i; int j0 = ao_loc[jsh ] - eri.ao_offsets_j; @@ -337,7 +365,7 @@ static void GINTwrite_int3c2e_direct(GINTEnvVars envs, ERITensor eri, double* g, int ix, iy, iz, off; for (n = 0, k = k0; k < k1; ++k) { - p_eri = eri.data + 0 * lstride + k * kstride; + p_eri = eri.data + k * kstride; for (j = j0; j < j1; ++j) { for (i = i0; i < i1; ++i, ++n) { diff --git a/gpu4pyscf/lib/gint/nr_fill_ao_int3c2e.cu b/gpu4pyscf/lib/gint/nr_fill_ao_int3c2e.cu index 0953fc10..bcb5e382 100644 --- a/gpu4pyscf/lib/gint/nr_fill_ao_int3c2e.cu +++ b/gpu4pyscf/lib/gint/nr_fill_ao_int3c2e.cu @@ -41,7 +41,7 @@ static int GINTfill_int3c2e_tasks(ERITensor *eri, BasisProdOffsets *offsets, GIN int ntasks_kl = offsets->ntasks_kl; assert(ntasks_kl < 65536*THREADSY); int type_ijkl; - + dim3 threads(THREADSX, THREADSY); dim3 blocks((ntasks_ij+THREADSX-1)/THREADSX, (ntasks_kl+THREADSY-1)/THREADSY); switch (nrys_roots) { @@ -121,14 +121,14 @@ int GINTfill_int3c2e(cudaStream_t stream, BasisProdCache *bpcache, double *eri, ContractionProdType *cp_kl = bpcache->cptype + cp_kl_id; GINTEnvVars envs; int ng[4] = {0,0,0,0}; - + GINTinit_EnvVars(&envs, cp_ij, cp_kl, ng); envs.omega = omega; if (envs.nrys_roots > 8) { return 2; } - + // TODO: improve the efficiency by unrolling if (envs.nrys_roots > 1) { int16_t *idx4c = (int16_t *)malloc(sizeof(int16_t) * envs.nf * 3); @@ -136,9 +136,9 @@ int GINTfill_int3c2e(cudaStream_t stream, BasisProdCache *bpcache, double *eri, checkCudaErrors(cudaMemcpyToSymbol(c_idx4c, idx4c, sizeof(int16_t)*envs.nf*3)); free(idx4c); } - + int kl_bin, ij_bin1; - + //checkCudaErrors(cudaMemcpyToSymbol(c_envs, &envs, sizeof(GINTEnvVars))); // move bpcache to constant memory checkCudaErrors(cudaMemcpyToSymbol(c_bpcache, bpcache, sizeof(BasisProdCache))); @@ -154,7 +154,7 @@ int GINTfill_int3c2e(cudaStream_t stream, BasisProdCache *bpcache, double *eri, eritensor.nao = nao; eritensor.data = eri; BasisProdOffsets offsets; - + int *bas_pairs_locs = bpcache->bas_pairs_locs; int *primitive_pairs_locs = bpcache->primitive_pairs_locs; for (kl_bin = 0; kl_bin < nbins; kl_bin++) { @@ -182,12 +182,12 @@ int GINTfill_int3c2e(cudaStream_t stream, BasisProdCache *bpcache, double *eri, int err = -1; err = GINTfill_int3c2e_tasks(&eritensor, &offsets, &envs, stream); - + if (err != 0) { return err; } } - + return 0; } } diff --git a/gpu4pyscf/lib/gvhf/g3c2e.cuh b/gpu4pyscf/lib/gvhf/g3c2e.cuh index c82d037c..9195cddd 100644 --- a/gpu4pyscf/lib/gvhf/g3c2e.cuh +++ b/gpu4pyscf/lib/gvhf/g3c2e.cuh @@ -58,7 +58,7 @@ static void GINTkernel_int3c2e_ip1_getjk(JKMatrix jk, double* __restrict__ gout, double sx = gout[3*n]; double sy = gout[3*n + 1]; double sz = gout[3*n + 2]; - + double rhoj_tmp = dm[off_dm] * rhoj[k]; double rhok_tmp = rhok[off_rhok]; @@ -73,7 +73,7 @@ static void GINTkernel_int3c2e_ip1_getjk(JKMatrix jk, double* __restrict__ gout, } } } - + for (i = i0; i < i1; ++i){ int ii = 3*(i-i0); atomicAdd(vj + i + 0*nao, j3[ii + 0]); @@ -107,22 +107,22 @@ static void GINTkernel_int3c2e_ip2_getjk(JKMatrix jk, double* __restrict__ gout, double* __restrict__ dm = jk.dm; double j3[GPU_CART_MAX * 3]; double k3[GPU_CART_MAX * 3]; - + for (k = 0; k < (k1-k0) * 3; k++){ j3[k] = 0.0; k3[k] = 0.0; } - + for (n = 0, k = k0; k < k1; ++k) { for (j = j0; j < j1; ++j) { for (i = i0; i < i1; ++i, ++n) { off_dm = i + nao*j; off_rhok = i + nao*j + k*nao*nao; - + double sx = gout[3 * n]; double sy = gout[3 * n + 1]; double sz = gout[3 * n + 2]; - + double rhoj_tmp = dm[off_dm] * rhoj[k]; double rhok_tmp = rhok[off_rhok]; @@ -143,7 +143,7 @@ static void GINTkernel_int3c2e_ip2_getjk(JKMatrix jk, double* __restrict__ gout, atomicAdd(vj + k + 0*naux, j3[kk + 0]); atomicAdd(vj + k + 1*naux, j3[kk + 1]); atomicAdd(vj + k + 2*naux, j3[kk + 2]); - + atomicAdd(vk + k + 0*naux, k3[kk + 0]); atomicAdd(vk + k + 1*naux, k3[kk + 1]); atomicAdd(vk + k + 2*naux, k3[kk + 2]); @@ -169,7 +169,7 @@ static void GINTkernel_int3c2e_getj_pass1(GINTEnvVars envs, JKMatrix jk, double* int i, j, k; double* __restrict__ rhoj = jk.rhoj; double* __restrict__ dm = jk.dm; - + int i_l = envs.i_l; int j_l = envs.j_l; int k_l = envs.k_l; @@ -243,11 +243,11 @@ static void GINTkernel_int3c2e_getj_pass2(GINTEnvVars envs, JKMatrix jk, double* double vj_tmp = 0.0; for (k = k0; k < k1; ++k){ int kp = k - k0; - + int loc_k = c_l_locs[k_l] + kp; int loc_j = c_l_locs[j_l] + jp; int loc_i = c_l_locs[i_l] + ip; - + int ix = dk * idx[loc_k] + dj * idx[loc_j] + di * idx[loc_i]; int iy = dk * idy[loc_k] + dj * idy[loc_j] + di * idy[loc_i] + envs.g_size; int iz = dk * idz[loc_k] + dj * idz[loc_j] + di * idz[loc_i] + envs.g_size * 2; @@ -306,7 +306,7 @@ static void GINTkernel_int3c2e_ip1_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int jp = j - j0; for (i = i0; i < i1; ++i) { int ip = i - i0; - + int loc_k = c_l_locs[k_l] + kp; int loc_j = c_l_locs[j_l] + jp; int loc_i = c_l_locs[i_l] + ip; @@ -327,7 +327,7 @@ static void GINTkernel_int3c2e_ip1_getjk_direct(GINTEnvVars envs, JKMatrix jk, d sy += gx * f[iy + ir] * gz; sz += gx * gy * f[iz + ir]; } - + int ii = 3*(i-i0); off_rhok = i + nao*j + k*nao*nao; double rhok_tmp = rhok[off_rhok]; @@ -348,7 +348,7 @@ static void GINTkernel_int3c2e_ip1_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int jp = j - j0; for (i = i0; i < i1; ++i) { int ip = i - i0; - + int loc_k = c_l_locs[k_l] + kp; int loc_j = c_l_locs[j_l] + jp; int loc_i = c_l_locs[i_l] + ip; @@ -388,7 +388,7 @@ static void GINTkernel_int3c2e_ip1_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int jp = j - j0; for (i = i0; i < i1; ++i) { int ip = i - i0; - + int loc_k = c_l_locs[k_l] + kp; int loc_j = c_l_locs[j_l] + jp; int loc_i = c_l_locs[i_l] + ip; @@ -453,7 +453,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d double* __restrict__ rhoj = jk.rhoj; double* __restrict__ rhok = jk.rhok; double* __restrict__ dm = jk.dm; - + int i_l = envs.i_l; int j_l = envs.j_l; int k_l = envs.k_l; @@ -468,7 +468,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int jp = j - j0; for (i = i0; i < i1; ++i) { int ip = i - i0; - + int loc_k = c_l_locs[k_l] + kp; int loc_j = c_l_locs[j_l] + jp; int loc_i = c_l_locs[i_l] + ip; @@ -476,7 +476,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int ix = dk * idx[loc_k] + dj * idx[loc_j] + di * idx[loc_i]; int iy = dk * idy[loc_k] + dj * idy[loc_j] + di * idy[loc_i] + g_size; int iz = dk * idz[loc_k] + dj * idz[loc_j] + di * idz[loc_i] + g_size * 2; - + double sx = 0.0; double sy = 0.0; double sz = 0.0; @@ -510,7 +510,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int jp = j - j0; for (i = i0; i < i1; ++i) { int ip = i - i0; - + int loc_k = c_l_locs[k_l] + kp; int loc_j = c_l_locs[j_l] + jp; int loc_i = c_l_locs[i_l] + ip; @@ -518,7 +518,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int ix = dk * idx[loc_k] + dj * idx[loc_j] + di * idx[loc_i]; int iy = dk * idy[loc_k] + dj * idy[loc_j] + di * idy[loc_i] + g_size; int iz = dk * idz[loc_k] + dj * idz[loc_j] + di * idz[loc_i] + g_size * 2; - + double sx = 0.0; double sy = 0.0; double sz = 0.0; @@ -551,7 +551,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int jp = j - j0; for (i = i0; i < i1; ++i) { int ip = i - i0; - + int loc_k = c_l_locs[k_l] + kp; int loc_j = c_l_locs[j_l] + jp; int loc_i = c_l_locs[i_l] + ip; @@ -559,7 +559,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d int ix = dk * idx[loc_k] + dj * idx[loc_j] + di * idx[loc_i]; int iy = dk * idy[loc_k] + dj * idy[loc_j] + di * idy[loc_i] + g_size; int iz = dk * idz[loc_k] + dj * idz[loc_j] + di * idz[loc_i] + g_size * 2; - + double sx = 0.0; double sy = 0.0; double sz = 0.0; @@ -579,7 +579,7 @@ static void GINTkernel_int3c2e_ip2_getjk_direct(GINTEnvVars envs, JKMatrix jk, d j3[kk + 0] += sx * rhoj_tmp; j3[kk + 1] += sy * rhoj_tmp; j3[kk + 2] += sz * rhoj_tmp; - + off_rhok = i + nao*j + k*nao*nao; double rhok_tmp = rhok[off_rhok]; k3[kk + 0] += sx * rhok_tmp; @@ -602,7 +602,7 @@ static void write_int3c2e_ip1_jk(JKMatrix jk, double* j3, double* k3, int ish){ int tx = threadIdx.x; int ty = threadIdx.y; __shared__ double sdata[THREADSX][THREADSY]; - + if (vj != NULL){ for (int i = i0; i < i1; ++i){ for (int j = 0; j < 3; j++){ @@ -644,7 +644,7 @@ static void write_int3c2e_ip2_jk(JKMatrix jk, double *j3, double* k3, int ksh){ int tx = threadIdx.x; int ty = threadIdx.y; __shared__ double sdata[THREADSX][THREADSY]; - + if (vj != NULL){ for (int k = k0; k < k1; ++k){ for (int j = 0; j < 3; j++){ @@ -741,7 +741,7 @@ static void GINTrun_int3c2e_ip1_jk_kernel1000(GINTEnvVars envs, JKMatrix jk, Bas double aijkl = aij + akl; double a1 = aij * akl; double a0 = a1 / aijkl; - double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0; + double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0; a0 *= theta; double x = a0 * (xijxkl * xijxkl + yijykl * yijykl + zijzkl * zijzkl); double fac = eij * ekl * sqrt(a0 / (a1 * a1 * a1)); @@ -770,11 +770,11 @@ static void GINTrun_int3c2e_ip1_jk_kernel1000(GINTEnvVars envs, JKMatrix jk, Bas double g_3 = c00y; double g_4 = norm * fac * weight0; double g_5 = g_4 * c00z; - + double f_1 = ai2 * g_1; double f_3 = ai2 * g_3; double f_5 = ai2 * g_5; - + gout0 += f_1 * g_2 * g_4; gout1 += g_0 * f_3 * g_4; gout2 += g_0 * g_2 * f_5; @@ -784,14 +784,14 @@ static void GINTrun_int3c2e_ip1_jk_kernel1000(GINTEnvVars envs, JKMatrix jk, Bas int i0 = ao_loc[ish] - jk.ao_offsets_i; int j0 = ao_loc[jsh] - jk.ao_offsets_j; int k0 = ao_loc[ksh] - jk.ao_offsets_k; - + int nao = jk.nao; double* __restrict__ dm = jk.dm; double* __restrict__ rhok = jk.rhok; double* __restrict__ rhoj = jk.rhoj; double* __restrict__ vj = jk.vj; double* __restrict__ vk = jk.vk; - + int tx = threadIdx.x; int ty = threadIdx.y; __shared__ double sdata[THREADSX][THREADSY]; @@ -903,7 +903,7 @@ static void GINTrun_int3c2e_ip2_jk_kernel0010(GINTEnvVars envs, JKMatrix jk, Bas double aijkl = aij + akl; double a1 = aij * akl; double a0 = a1 / aijkl; - double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0; + double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0; a0 *= theta; double x = a0 * (xijxkl * xijxkl + yijykl * yijykl + zijzkl * zijzkl); double fac = norm * eij * ekl * sqrt(a0 / (a1 * a1 * a1)); diff --git a/gpu4pyscf/lib/gvhf/g3c2e_pass1.cu b/gpu4pyscf/lib/gvhf/g3c2e_pass1.cu index 0bd30319..cc100d15 100644 --- a/gpu4pyscf/lib/gvhf/g3c2e_pass1.cu +++ b/gpu4pyscf/lib/gvhf/g3c2e_pass1.cu @@ -13,7 +13,7 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ - + template __global__ void GINTint3c2e_pass1_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets offsets) @@ -22,7 +22,7 @@ void GINTint3c2e_pass1_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets int ntasks_kl = offsets.ntasks_kl; int task_ij = blockIdx.x * blockDim.x + threadIdx.x; int task_kl = blockIdx.y * blockDim.y + threadIdx.y; - + if (task_ij >= ntasks_ij || task_kl >= ntasks_kl) { return; } @@ -42,12 +42,16 @@ void GINTint3c2e_pass1_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets int lsh = bas_pair2ket[bas_kl]; double uw[NROOTS*2]; double g[GSIZE]; - + double* __restrict__ a12 = c_bpcache.a12; double* __restrict__ x12 = c_bpcache.x12; double* __restrict__ y12 = c_bpcache.y12; double* __restrict__ z12 = c_bpcache.z12; + if (ish == jsh){ + norm *= .5; + } + int ij, kl; int as_ish, as_jsh, as_ksh, as_lsh; if (envs.ibase) { @@ -70,7 +74,7 @@ void GINTint3c2e_pass1_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets double xij = x12[ij]; double yij = y12[ij]; double zij = z12[ij]; - for (kl = prim_kl; kl < prim_kl+nprim_kl; ++kl) { + for (kl = prim_kl; kl < prim_kl+nprim_kl; ++kl) { double akl = a12[kl]; double xkl = x12[kl]; double ykl = y12[kl]; diff --git a/gpu4pyscf/lib/gvhf/g3c2e_pass1_root1.cu b/gpu4pyscf/lib/gvhf/g3c2e_pass1_root1.cu index cb9dbf11..1216726a 100644 --- a/gpu4pyscf/lib/gvhf/g3c2e_pass1_root1.cu +++ b/gpu4pyscf/lib/gvhf/g3c2e_pass1_root1.cu @@ -41,6 +41,9 @@ static void GINTint3c2e_pass1_j_kernel0000(GINTEnvVars envs, JKMatrix jk, BasisP double* __restrict__ x12 = c_bpcache.x12; double* __restrict__ y12 = c_bpcache.y12; double* __restrict__ z12 = c_bpcache.z12; + if (ish == jsh){ + norm *= .5; + } int ij, kl; double gout0 = 0; for (ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) { @@ -70,7 +73,7 @@ static void GINTint3c2e_pass1_j_kernel0000(GINTEnvVars envs, JKMatrix jk, BasisP } gout0 += fac; } } - + int *ao_loc = c_bpcache.ao_loc; int nao = jk.nao; int i0 = ao_loc[ish] - jk.ao_offsets_i; @@ -107,6 +110,9 @@ static void GINTint3c2e_pass1_j_kernel0010(GINTEnvVars envs, JKMatrix jk, BasisP double* __restrict__ x12 = c_bpcache.x12; double* __restrict__ y12 = c_bpcache.y12; double* __restrict__ z12 = c_bpcache.z12; + if (ish == jsh){ + norm *= .5; + } int ij, kl; int prim_ij0, prim_ij1, prim_kl0, prim_kl1; int nbas = c_bpcache.nbas; @@ -215,6 +221,9 @@ static void GINTint3c2e_pass1_j_kernel1000(GINTEnvVars envs, JKMatrix jk, BasisP double* __restrict__ x12 = c_bpcache.x12; double* __restrict__ y12 = c_bpcache.y12; double* __restrict__ z12 = c_bpcache.z12; + if (ish == jsh){ + norm *= .5; + } int ij, kl; int prim_ij0, prim_ij1, prim_kl0, prim_kl1; int nbas = c_bpcache.nbas; diff --git a/gpu4pyscf/lib/gvhf/g3c2e_pass2.cu b/gpu4pyscf/lib/gvhf/g3c2e_pass2.cu index 83d37261..6e395ce7 100644 --- a/gpu4pyscf/lib/gvhf/g3c2e_pass2.cu +++ b/gpu4pyscf/lib/gvhf/g3c2e_pass2.cu @@ -13,7 +13,7 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ - + template __global__ void GINTint3c2e_pass2_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets offsets) @@ -22,7 +22,7 @@ void GINTint3c2e_pass2_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets int ntasks_kl = offsets.ntasks_kl; int task_ij = blockIdx.x * blockDim.x + threadIdx.x; int task_kl = blockIdx.y * blockDim.y + threadIdx.y; - + if (task_ij >= ntasks_ij || task_kl >= ntasks_kl) { return; } @@ -42,7 +42,9 @@ void GINTint3c2e_pass2_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets int lsh = bas_pair2ket[bas_kl]; double uw[NROOTS*2]; double g[GSIZE]; - + if (ish == jsh){ + norm *= .5; + } double* __restrict__ a12 = c_bpcache.a12; double* __restrict__ x12 = c_bpcache.x12; double* __restrict__ y12 = c_bpcache.y12; @@ -70,7 +72,7 @@ void GINTint3c2e_pass2_j_kernel(GINTEnvVars envs, JKMatrix jk, BasisProdOffsets double xij = x12[ij]; double yij = y12[ij]; double zij = z12[ij]; - for (kl = prim_kl; kl < prim_kl+nprim_kl; ++kl) { + for (kl = prim_kl; kl < prim_kl+nprim_kl; ++kl) { double akl = a12[kl]; double xkl = x12[kl]; double ykl = y12[kl]; diff --git a/gpu4pyscf/lib/gvhf/g3c2e_pass2_root1.cu b/gpu4pyscf/lib/gvhf/g3c2e_pass2_root1.cu index 2f76578c..47c65d3f 100644 --- a/gpu4pyscf/lib/gvhf/g3c2e_pass2_root1.cu +++ b/gpu4pyscf/lib/gvhf/g3c2e_pass2_root1.cu @@ -42,6 +42,9 @@ static void GINTint3c2e_pass2_j_kernel0000(GINTEnvVars envs, JKMatrix jk, BasisP double* __restrict__ x12 = c_bpcache.x12; double* __restrict__ y12 = c_bpcache.y12; double* __restrict__ z12 = c_bpcache.z12; + if (ish == jsh){ + norm *= .5; + } int ij, kl; double gout0 = 0; for (ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) { @@ -71,7 +74,7 @@ static void GINTint3c2e_pass2_j_kernel0000(GINTEnvVars envs, JKMatrix jk, BasisP } gout0 += fac; } } - + int *ao_loc = c_bpcache.ao_loc; int nao = jk.nao; int i0 = ao_loc[ish] - jk.ao_offsets_i; @@ -113,7 +116,9 @@ static void GINTint3c2e_pass2_j_kernel0010(GINTEnvVars envs, JKMatrix jk, BasisP double* __restrict__ bas_x = c_bpcache.bas_coords; double* __restrict__ bas_y = bas_x + nbas; double* __restrict__ bas_z = bas_y + nbas; - + if (ish == jsh){ + norm *= .5; + } double gout0 = 0; double gout1 = 0; double gout2 = 0; @@ -219,7 +224,9 @@ static void GINTint3c2e_pass2_j_kernel1000(GINTEnvVars envs, JKMatrix jk, BasisP double* __restrict__ bas_x = c_bpcache.bas_coords; double* __restrict__ bas_y = bas_x + nbas; double* __restrict__ bas_z = bas_y + nbas; - + if (ish == jsh){ + norm *= .5; + } double gout0 = 0; double gout1 = 0; double gout2 = 0; diff --git a/gpu4pyscf/lib/logger.py b/gpu4pyscf/lib/logger.py index 7b46fb27..5ca2ec13 100644 --- a/gpu4pyscf/lib/logger.py +++ b/gpu4pyscf/lib/logger.py @@ -25,6 +25,7 @@ WARN = lib.logger.WARN DEBUG = lib.logger.DEBUG DEBUG1= lib.logger.DEBUG1 +DEBUG2= lib.logger.DEBUG2 TIMER_LEVEL = lib.logger.TIMER_LEVEL flush = lib.logger.flush @@ -84,17 +85,24 @@ def _timer_debug1(rec, msg, cpu0=None, wall0=None, gpu0=None, sync=True): rec._t0 = process_clock() return rec._t0, +def _timer_debug2(rec, msg, cpu0=None, wall0=None, gpu0=None, sync=True): + if rec.verbose >= DEBUG2: + return timer(rec, msg, cpu0, wall0, gpu0) + return cpu0, wall0, gpu0 + info = lib.logger.info note = lib.logger.note debug = lib.logger.debug debug1 = lib.logger.debug1 debug2 = lib.logger.debug2 timer_debug1 = _timer_debug1 +timer_debug2 = _timer_debug2 class Logger(lib.logger.Logger): def __init__(self, stdout=sys.stdout, verbose=NOTE): super().__init__(stdout=stdout, verbose=verbose) timer_debug1 = _timer_debug1 + timer_debug2 = _timer_debug2 timer = timer init_timer = init_timer diff --git a/gpu4pyscf/lib/tests/test_cupy_helper.py b/gpu4pyscf/lib/tests/test_cupy_helper.py index fed01456..befbf827 100644 --- a/gpu4pyscf/lib/tests/test_cupy_helper.py +++ b/gpu4pyscf/lib/tests/test_cupy_helper.py @@ -18,7 +18,7 @@ import cupy from gpu4pyscf.lib.cupy_helper import ( take_last2d, transpose_sum, krylov, unpack_sparse, - add_sparse, takebak, empty_mapped) + add_sparse, takebak, empty_mapped, dist_matrix) class KnownValues(unittest.TestCase): def test_take_last2d(self): @@ -69,6 +69,13 @@ def test_sparse(self): add_sparse(a, b, indices) assert cupy.linalg.norm(a - a0) < 1e-10 + def test_dist_matrix(self): + a = cupy.random.rand(4, 3) + rij = cupy.sum((a[:,None,:] - a[None,:,:])**2, axis=2)**0.5 + + rij0 = dist_matrix(a) + assert cupy.linalg.norm(rij - rij0) < 1e-10 + def test_takebak(self): a = empty_mapped((5, 8)) a[:] = 1. diff --git a/gpu4pyscf/solvent/grad/pcm.py b/gpu4pyscf/solvent/grad/pcm.py index be469508..cf6cda8f 100644 --- a/gpu4pyscf/solvent/grad/pcm.py +++ b/gpu4pyscf/solvent/grad/pcm.py @@ -114,6 +114,7 @@ def get_dD_dS(surface, dF, with_S=True, with_D=False): rij = cupy.linalg.norm(ri_rj, axis=-1) xi_r_ij = xi_ij * rij cupy.fill_diagonal(rij, 1) + xi_i = xi_j = None dS_dr = -(scipy.special.erf(xi_r_ij) - 2.0*xi_r_ij/PI**0.5*cupy.exp(-xi_r_ij**2))/rij**2 cupy.fill_diagonal(dS_dr, 0) @@ -134,13 +135,16 @@ def get_dD_dS(surface, dF, with_S=True, with_D=False): dD_dri = cupy.expand_dims(dD_dri, axis=-1) dD = dD_dri * drij + dS_dr * (-nj/rij + 3.0*nj_rij/rij**2 * drij) - + dD_dri = None dSii_dF = -exponents * (2.0/PI)**0.5 / switch_fun**2 dSii = cupy.expand_dims(dSii_dF, axis=(1,2)) * dF return dD, dS, dSii def grad_nuc(pcmobj, dm): + mol = pcmobj.mol + log = logger.new_logger(mol, mol.verbose) + t1 = log.init_timer() if not pcmobj._intermediates or 'q_sym' not in pcmobj._intermediates: pcmobj._get_vind(dm) @@ -168,6 +172,7 @@ def grad_nuc(pcmobj, dm): dv_g = numpy.einsum('gx,g->gx', dv_g, q_sym) de -= numpy.asarray([numpy.sum(dv_g[p0:p1], axis=0) for p0,p1 in gridslice]) + t1 = log.timer_debug1('grad nuc', *t1) return de def grad_qv(pcmobj, dm): @@ -176,7 +181,9 @@ def grad_qv(pcmobj, dm): ''' if not pcmobj._intermediates or 'q_sym' not in pcmobj._intermediates: pcmobj._get_vind(dm) - + mol = pcmobj.mol + log = logger.new_logger(mol, mol.verbose) + t1 = log.init_timer() gridslice = pcmobj.surface['gslice_by_atom'] q_sym = pcmobj._intermediates['q_sym'] @@ -199,6 +206,7 @@ def grad_qv(pcmobj, dm): dq = cupy.asarray([cupy.sum(dq[:,p0:p1], axis=1) for p0,p1 in gridslice]) dvj= 2.0 * cupy.asarray([cupy.sum(dvj[:,p0:p1], axis=1) for p0,p1 in aoslice[:,2:]]) de = dq + dvj + t1 = log.timer_debug1('grad qv', *t1) return de.get() def grad_solver(pcmobj, dm): @@ -206,6 +214,9 @@ def grad_solver(pcmobj, dm): dE = 0.5*v* d(K^-1 R) *v + q*dv v^T* d(K^-1 R)v = v^T*K^-1(dR - dK K^-1R)v = v^T K^-1(dR - dK q) ''' + mol = pcmobj.mol + log = logger.new_logger(mol, mol.verbose) + t1 = log.init_timer() if not pcmobj._intermediates or 'q_sym' not in pcmobj._intermediates: pcmobj._get_vind(dm) @@ -300,7 +311,7 @@ def contract_ket(a, B, c): de += de_dR - de_dK else: raise RuntimeError(f"Unknown implicit solvent model: {pcmobj.method}") - + t1 = log.timer_debug1('grad solver', *t1) return de.get() def make_grad_object(grad_method): diff --git a/gpu4pyscf/solvent/hessian/pcm.py b/gpu4pyscf/solvent/hessian/pcm.py index 4fe16261..ae28942c 100644 --- a/gpu4pyscf/solvent/hessian/pcm.py +++ b/gpu4pyscf/solvent/hessian/pcm.py @@ -132,7 +132,7 @@ def pcm_grad_scanner(mol): e, v = pcmobj._get_vind(dm) #return grad_elec(pcmobj, dm) return grad_nuc(pcmobj, dm) + grad_solver(pcmobj, dm) + grad_qv(pcmobj, dm) - + mol.verbose = 0 de = numpy.zeros([mol.natm, mol.natm, 3, 3]) eps = 1e-3 for ia in range(mol.natm): @@ -172,7 +172,8 @@ def pcm_vmat_scanner(mol): e, v = pcmobj._get_vind(dm) return v - vmat = cupy.zeros([len(atmlst), 3, nao, nocc]) + mol.verbose = 0 + vmat = cupy.empty([len(atmlst), 3, nao, nocc]) eps = 1e-3 for i0, ia in enumerate(atmlst): for ix in range(3): diff --git a/gpu4pyscf/solvent/pcm.py b/gpu4pyscf/solvent/pcm.py index 428e5a08..335de1f7 100644 --- a/gpu4pyscf/solvent/pcm.py +++ b/gpu4pyscf/solvent/pcm.py @@ -29,6 +29,7 @@ from gpu4pyscf.solvent import _attach_solvent from gpu4pyscf.df import int3c2e from gpu4pyscf.lib import logger +from gpu4pyscf.lib.cupy_helper import dist_matrix libdft = lib.load_library('libdft') @@ -191,7 +192,8 @@ def get_D_S(surface, with_S=True, with_D=False): xi_i, xi_j = cupy.meshgrid(charge_exp, charge_exp, indexing='ij') xi_ij = xi_i * xi_j / (xi_i**2 + xi_j**2)**0.5 #rij = scipy.spatial.distance.cdist(grid_coords, grid_coords) - rij = cupy.sum((grid_coords[:,None,:] - grid_coords[None,:,:])**2, axis=2)**0.5 + #rij = cupy.sum((grid_coords[:,None,:] - grid_coords[None,:,:])**2, axis=2)**0.5 + rij = dist_matrix(grid_coords) xi_r_ij = xi_ij * rij cupy.fill_diagonal(rij, 1) S = scipy.special.erf(xi_r_ij) / rij @@ -209,7 +211,7 @@ def get_D_S(surface, with_S=True, with_D=False): class PCM(ddcosmo.DDCOSMO): _keys = { - 'method', 'vdw_scale', 'surface' + 'method', 'vdw_scale', 'surface', 'r_probe', 'intopt' } def __init__(self, mol): ddcosmo.DDCOSMO.__init__(self, mol) From f61d4d5f9ac11bd6972e7c9c610cb2a98dd5f65f Mon Sep 17 00:00:00 2001 From: Xiaojie Wu Date: Tue, 16 Jan 2024 13:33:05 -0800 Subject: [PATCH 06/10] Dftdx (#83) * cmake workflow for dftd3 and dftd4 * updated cmake * new workflow for dftd3 and dftd4 * remove package_data * resolve dependencies in dftd3 and dftd4 * add dftd3 and dftd4 to __init__.py * updated the script for building wheels --- builder/build_dftdx.sh | 54 ------------ dockerfiles/manylinux/build_wheels.sh | 3 + gpu4pyscf/dft/rks.py | 22 ++--- gpu4pyscf/grad/rks.py | 22 ++--- gpu4pyscf/hessian/rks.py | 97 +++++++++++----------- gpu4pyscf/lib/CMakeLists.txt | 58 +++++++++++++ gpu4pyscf/lib/__init__.py | 12 ++- gpu4pyscf/lib/cupy_helper/take_last2d.cu | 2 +- gpu4pyscf/lib/dftd3.py | 100 +++++++++++++++++++++++ gpu4pyscf/lib/dftd4.py | 91 +++++++++++++++++++++ gpu4pyscf/lib/tests/test_cupy_helper.py | 2 +- gpu4pyscf/lib/tests/test_dftd3.py | 100 +++++++++++++++++++++++ gpu4pyscf/lib/tests/test_dftd4.py | 99 ++++++++++++++++++++++ setup.py | 34 +------- 14 files changed, 527 insertions(+), 169 deletions(-) delete mode 100644 builder/build_dftdx.sh create mode 100644 gpu4pyscf/lib/dftd3.py create mode 100644 gpu4pyscf/lib/dftd4.py create mode 100644 gpu4pyscf/lib/tests/test_dftd3.py create mode 100644 gpu4pyscf/lib/tests/test_dftd4.py diff --git a/builder/build_dftdx.sh b/builder/build_dftdx.sh deleted file mode 100644 index 05e43e3c..00000000 --- a/builder/build_dftdx.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -WORK_DIR="./tmp" -rm -r ${WORK_DIR} -mkdir -p ${WORK_DIR} - -PROJECT_NAME=${PROJECT_NAME:-"dftd3"} - -SOURCE_URL=${SOURCE_URL:-"https://github.com/dftd3/simple-dftd3/releases/download/v1.0.0/dftd3-1.0.0-sdist.tar.gz"} - -TAR_GZ_NAME=$(basename ${SOURCE_URL}) - -BUILD_DIR="${WORK_DIR}/_build" -INSTALL_DIR="${WORK_DIR}/${PROJECT_NAME}-build" - -pip3 install meson ninja - -cd ${WORK_DIR} - -echo "Downloading source code from $SOURCE_URL..." -curl -L $SOURCE_URL -o $TAR_GZ_NAME - -echo "Extracting $TAR_GZ_NAME..." -tar -xzf $TAR_GZ_NAME - -SOURCE_DIR=$(tar -tf $TAR_GZ_NAME | head -1 | cut -f1 -d"/") -cd $SOURCE_DIR - -echo " -option( - 'openmp', - type: 'boolean', - value: false, - yield: true, - description: 'Use OpenMP parallelisation', -)" >> meson_options.txt - -echo "Setting up build system with meson..." -meson setup --wipe $BUILD_DIR -Dopenmp=false - -echo "Compiling the code..." -meson compile -C $BUILD_DIR - -echo "Configuring build system with prefix..." -meson configure $BUILD_DIR --prefix=$(realpath ${INSTALL_DIR}) - -echo "Installing to $INSTALL_DIR..." -meson install -C $BUILD_DIR - -echo "Installation complete." - -cd ../../ - -echo "All operations completed." \ No newline at end of file diff --git a/dockerfiles/manylinux/build_wheels.sh b/dockerfiles/manylinux/build_wheels.sh index b610e394..a4f444f9 100644 --- a/dockerfiles/manylinux/build_wheels.sh +++ b/dockerfiles/manylinux/build_wheels.sh @@ -14,6 +14,9 @@ export CUTENSOR_DIR=/usr/local/cuda export PATH=$CUDA_HOME/bin:$PATH export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH +# blas is required by DFTD3 and DFTD4 +yum install openblas-devel + # Compile wheels rm -rf /gpu4pyscf/wheelhouse for PYBIN in /opt/python/cp311-cp311/bin; do diff --git a/gpu4pyscf/dft/rks.py b/gpu4pyscf/dft/rks.py index 34ae52dd..c7a21eb6 100644 --- a/gpu4pyscf/dft/rks.py +++ b/gpu4pyscf/dft/rks.py @@ -248,23 +248,15 @@ def get_dispersion(self): return 0.0 if self.disp[:2].upper() == 'D3': - # multi-threads in DFTD3 conflicts with PyTorch, set it to be 1 for safty - from pyscf import lib - with lib.with_omp_threads(1): - import gpu4pyscf.dftd3.pyscf as disp - d3 = disp.DFTD3Dispersion(self.mol, xc=self.xc, version=self.disp) - e_d3, _ = d3.kernel() - return e_d3 + from gpu4pyscf.lib import dftd3 + dftd3_model = dftd3.DFTD3Dispersion(self.mol, xc=self.xc, version=self.disp) + res = dftd3_model.get_dispersion() + return res['energy'] if self.disp[:2].upper() == 'D4': - from pyscf.data.elements import charge - atoms = numpy.array([ charge(a[0]) for a in self.mol._atom]) - coords = self.mol.atom_coords() - from pyscf import lib - with lib.with_omp_threads(1): - from gpu4pyscf.dftd4.interface import DampingParam, DispersionModel - model = DispersionModel(atoms, coords) - res = model.get_dispersion(DampingParam(method=self.xc), grad=False) + from gpu4pyscf.lib import dftd4 + dftd4_model = dftd4.DFTD4Dispersion(self.mol, xc=self.xc) + res = dftd4_model.get_dispersion() return res.get("energy") def reset(self, mol=None): diff --git a/gpu4pyscf/grad/rks.py b/gpu4pyscf/grad/rks.py index 2bb22df6..08af545b 100644 --- a/gpu4pyscf/grad/rks.py +++ b/gpu4pyscf/grad/rks.py @@ -510,21 +510,13 @@ class Gradients(rhf_grad.Gradients, pyscf.grad.rks.Gradients): def get_dispersion(self): if self.base.disp[:2].upper() == 'D3': - from pyscf import lib - with lib.with_omp_threads(1): - import gpu4pyscf.dftd3.pyscf as disp - d3 = disp.DFTD3Dispersion(self.mol, xc=self.base.xc, version=self.base.disp) - _, g_d3 = d3.kernel() - return g_d3 + from gpu4pyscf.lib import dftd3 + dftd3_model = dftd3.DFTD3Dispersion(self.base.mol, xc=self.base.xc, version=self.base.disp) + res = dftd3_model.get_dispersion(grad=True) + return res['gradient'] if self.base.disp[:2].upper() == 'D4': - from pyscf.data.elements import charge - atoms = numpy.array([ charge(a[0]) for a in self.mol._atom]) - coords = self.mol.atom_coords() - - from pyscf import lib - with lib.with_omp_threads(1): - from gpu4pyscf.dftd4.interface import DampingParam, DispersionModel - model = DispersionModel(atoms, coords) - res = model.get_dispersion(DampingParam(method=self.base.xc), grad=True) + from gpu4pyscf.lib import dftd4 + dftd4_model = dftd4.DFTD4Dispersion(self.base.mol, xc=self.base.xc) + res = dftd4_model.get_dispersion(grad=True) return res.get("gradient") \ No newline at end of file diff --git a/gpu4pyscf/hessian/rks.py b/gpu4pyscf/hessian/rks.py index f808ecde..b8ba5da3 100644 --- a/gpu4pyscf/hessian/rks.py +++ b/gpu4pyscf/hessian/rks.py @@ -688,60 +688,59 @@ def to_cpu(self): def get_dispersion(self): if self.base.disp[:2].upper() == 'D3': - from pyscf import lib - with lib.with_omp_threads(1): - import gpu4pyscf.dftd3.pyscf as disp - coords = self.mol.atom_coords() - natm = self.mol.natm - h_d3 = numpy.zeros([self.mol.natm, self.mol.natm, 3,3]) - mol = self.mol.copy() - eps = 1e-5 - for i in range(natm): - for j in range(3): - coords[i,j] += eps - mol.set_geom_(coords, unit='Bohr') - d3 = disp.DFTD3Dispersion(mol, xc=self.base.xc, version=self.base.disp) - _, g1 = d3.kernel() - - coords[i,j] -= 2.0*eps - mol.set_geom_(coords, unit='Bohr') - d3 = disp.DFTD3Dispersion(mol, xc=self.base.xc, version=self.base.disp) - _, g2 = d3.kernel() - - coords[i,j] += eps - h_d3[i,:,j,:] = (g1 - g2)/(2.0*eps) + from gpu4pyscf.lib import dftd3 + coords = self.mol.atom_coords() + natm = self.mol.natm + h_d3 = numpy.zeros([self.mol.natm, self.mol.natm, 3,3]) + mol = self.mol.copy() + eps = 1e-5 + for i in range(natm): + for j in range(3): + coords[i,j] += eps + mol.set_geom_(coords, unit='Bohr') + mol.build() + dftd3_model = dftd3.DFTD3Dispersion(mol, xc=self.base.xc, version=self.base.disp) + res = dftd3_model.get_dispersion(grad=True) + g1 = res['gradient'] + + coords[i,j] -= 2.0*eps + mol.set_geom_(coords, unit='Bohr') + mol.build() + dftd3_model = dftd3.DFTD3Dispersion(mol, xc=self.base.xc, version=self.base.disp) + res = dftd3_model.get_dispersion(grad=True) + g2 = res['gradient'] + + coords[i,j] += eps + h_d3[i,:,j,:] = (g1 - g2)/(2.0*eps) return h_d3 if self.base.disp[:2].upper() == 'D4': - from pyscf.data.elements import charge - atoms = numpy.array([ charge(a[0]) for a in self.mol._atom]) + from gpu4pyscf.lib import dftd4 coords = self.mol.atom_coords() natm = self.mol.natm - from pyscf import lib - with lib.with_omp_threads(1): - from gpu4pyscf.dftd4.interface import DampingParam, DispersionModel - params = DampingParam(method=self.base.xc) - mol = self.mol.copy() - h_d3 = numpy.zeros([self.mol.natm, self.mol.natm, 3,3]) - eps = 1e-5 - for i in range(natm): - for j in range(3): - coords[i,j] += eps - mol.set_geom_(coords, unit='Bohr') - model = DispersionModel(atoms, coords) - res = model.get_dispersion(params, grad=True) - g1 = res.get("gradient") - - coords[i,j] -= 2.0*eps - mol.set_geom_(coords, unit='Bohr') - model = DispersionModel(atoms, coords) - res = model.get_dispersion(params, grad=True) - g2 = res.get("gradient") - - coords[i,j] += eps - h_d3[i,:,j,:] = (g1 - g2)/(2.0*eps) - - return h_d3 + mol = self.mol.copy() + h_d4 = numpy.zeros([mol.natm, mol.natm, 3,3]) + eps = 1e-5 + for i in range(natm): + for j in range(3): + coords[i,j] += eps + mol.set_geom_(coords, unit='Bohr') + mol.build() + dftd4_model = dftd4.DFTD4Dispersion(mol, xc=self.base.xc) + res = dftd4_model.get_dispersion(grad=True) + g1 = res.get("gradient") + + coords[i,j] -= 2.0*eps + mol.set_geom_(coords, unit='Bohr') + mol.build() + dftd4_model = dftd4.DFTD4Dispersion(mol, xc=self.base.xc) + res = dftd4_model.get_dispersion(grad=True) + g2 = res.get("gradient") + + coords[i,j] += eps + h_d4[i,:,j,:] = (g1 - g2)/(2.0*eps) + + return h_d4 partial_hess_elec = partial_hess_elec make_h1 = make_h1 diff --git a/gpu4pyscf/lib/CMakeLists.txt b/gpu4pyscf/lib/CMakeLists.txt index f9e74f42..cd7310c9 100644 --- a/gpu4pyscf/lib/CMakeLists.txt +++ b/gpu4pyscf/lib/CMakeLists.txt @@ -115,3 +115,61 @@ if(BUILD_LIBXC) CMAKE_CACHE_ARGS -DCMAKE_CUDA_ARCHITECTURES:STRING=${CMAKE_CUDA_ARCHITECTURES} ) endif() + +# ---- compilation for dftd3 and dft4 +# 1. build static dependencies +# 2. build dftd3 and dftd4 shared libs, dftd3 and dftd4 will automatically search their dependencies +# https://github.com/dftd4/dftd4/blob/3fc00439c6abea2639868b644c52f0920d6c2e22/config/cmake/Findmstore.cmake#L24 +option(BUILD_DFTD3 "Using DFTD3 for DFT" ON) +if(BUILD_DFTD3) + include(ExternalProject) + ExternalProject_Add(dftd3_static + GIT_REPOSITORY "https://github.com/dftd3/simple-dftd3" + GIT_TAG v1.0.0 + PREFIX ${PROJECT_BINARY_DIR}/deps + INSTALL_DIR ${PROJECT_SOURCE_DIR}/deps + CMAKE_ARGS -DWITH_OpenMP=OFF + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_INSTALL_LIBDIR:PATH=lib + -DCMAKE_CURRENT_SOURCE_DIR:PATH=dftd3 + ) + + include(ExternalProject) + ExternalProject_Add(dftd3 + GIT_REPOSITORY "https://github.com/dftd3/simple-dftd3" + GIT_TAG v1.0.0 + PREFIX ${PROJECT_BINARY_DIR}/deps + INSTALL_DIR ${PROJECT_SOURCE_DIR}/deps + CMAKE_ARGS -DWITH_OpenMP=OFF -DBUILD_SHARED_LIBS=ON + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_INSTALL_LIBDIR:PATH=lib + -DCMAKE_CURRENT_SOURCE_DIR:PATH=dftd3 + ) + add_dependencies(dftd3 dftd3_static) +endif() + +option(BUILD_DFTD4 "Using DFTD4 for DFT" ON) +if(BUILD_DFTD4) + include(ExternalProject) + ExternalProject_Add(dftd4_static + GIT_REPOSITORY "https://github.com/dftd4/dftd4" + GIT_TAG v3.6.0 + PREFIX ${PROJECT_BINARY_DIR}/deps + INSTALL_DIR ${PROJECT_SOURCE_DIR}/deps + CMAKE_ARGS -DWITH_OpenMP=OFF + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_INSTALL_LIBDIR:PATH=lib + ) + + include(ExternalProject) + ExternalProject_Add(dftd4 + GIT_REPOSITORY "https://github.com/dftd4/dftd4" + GIT_TAG v3.6.0 + PREFIX ${PROJECT_BINARY_DIR}/deps + INSTALL_DIR ${PROJECT_SOURCE_DIR}/deps + CMAKE_ARGS -DWITH_OpenMP=OFF -DBUILD_SHARED_LIBS=ON + -DCMAKE_INSTALL_PREFIX:PATH= + -DCMAKE_INSTALL_LIBDIR:PATH=lib + ) + add_dependencies(dftd4 dftd4_static) +endif() \ No newline at end of file diff --git a/gpu4pyscf/lib/__init__.py b/gpu4pyscf/lib/__init__.py index 147324d9..e78e0087 100644 --- a/gpu4pyscf/lib/__init__.py +++ b/gpu4pyscf/lib/__init__.py @@ -17,4 +17,14 @@ import numpy from gpu4pyscf.lib import diis from gpu4pyscf.lib import cupy_helper -from gpu4pyscf.lib import cutensor \ No newline at end of file +from gpu4pyscf.lib import cutensor + +try: + from gpu4pyscf.lib import dftd3 +except Exception: + print('failed to load DFTD3') + +try: + from gpu4pyscf.lib import dftd4 +except Exception: + print('failed to load DFTD4') diff --git a/gpu4pyscf/lib/cupy_helper/take_last2d.cu b/gpu4pyscf/lib/cupy_helper/take_last2d.cu index 4b671211..d24f3163 100644 --- a/gpu4pyscf/lib/cupy_helper/take_last2d.cu +++ b/gpu4pyscf/lib/cupy_helper/take_last2d.cu @@ -42,7 +42,7 @@ static void _takebak(double *out, double *a, int *indices, { int i0 = blockIdx.y * COUNT_BLOCK; int j = blockIdx.x * blockDim.x + threadIdx.x; - if (j > n_a) { + if (j >= n_a) { return; } diff --git a/gpu4pyscf/lib/dftd3.py b/gpu4pyscf/lib/dftd3.py new file mode 100644 index 00000000..6c28bee3 --- /dev/null +++ b/gpu4pyscf/lib/dftd3.py @@ -0,0 +1,100 @@ +# Copyright 2023 The GPU4PySCF Authors. All Rights Reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import os +import numpy as np +import ctypes +from pyscf import lib, gto + +libdftd3 = np.ctypeslib.load_library('libs-dftd3', os.path.abspath(os.path.join(__file__, '..', 'deps', 'lib'))) + +_load_damping_param = { + "d3bj": libdftd3.dftd3_load_rational_damping, #RationalDampingParam, + "d3zero": libdftd3.dftd3_load_zero_damping, #ZeroDampingParam, + "d3bjm": libdftd3.dftd3_load_mrational_damping, #ModifiedRationalDampingParam, + "d3mbj": libdftd3.dftd3_load_mrational_damping, #ModifiedRationalDampingParam, + "d3zerom":libdftd3.dftd3_load_mzero_damping, #ModifiedZeroDampingParam, + "d3mzero":libdftd3.dftd3_load_mzero_damping, #ModifiedZeroDampingParam, + "d3op": libdftd3.dftd3_load_optimizedpower_damping #OptimizedPowerDampingParam, +} + + +libdftd3.dftd3_new_error.restype = ctypes.c_void_p +libdftd3.dftd3_new_structure.restype = ctypes.c_void_p + +class DFTD3Dispersion(lib.StreamObject): + def __init__(self, mol, xc, version='d3bj', atm=False): + coords = np.asarray(mol.atom_coords(), dtype=np.double, order='C') + nuc_types = [gto.charge(mol.atom_symbol(ia)) + for ia in range(mol.natm)] + nuc_types = np.asarray(nuc_types, dtype=np.int32) + self.natm = mol.natm + self._lattice = lib.c_null_ptr() + self._periodic = lib.c_null_ptr() + + err = libdftd3.dftd3_new_error() + self._mol = libdftd3.dftd3_new_structure( + err, + ctypes.c_int(mol.natm), + nuc_types.ctypes.data_as(ctypes.c_void_p), + coords.ctypes.data_as(ctypes.c_void_p), + self._lattice, + self._periodic, + ) + + self._disp = libdftd3.dftd3_new_d3_model(err, self._mol) + self._param = _load_damping_param[version]( + err, + ctypes.create_string_buffer(xc.encode(), size=50), + ctypes.c_bool(atm)) + libdftd3.dftd3_delete_error(err) + + def __del__(self): + err = libdftd3.dftd3_new_error() + libdftd3.dftd3_delete_structure(err, self._mol) + libdftd3.dftd3_delete_param(err, self._param) + libdftd3.dftd3_delete_model(err, self._disp) + libdftd3.dftd3_delete_error(err) + + def get_dispersion(self, grad=False): + res = {} + _energy = np.array(0.0, dtype=np.double) + if grad: + _gradient = np.zeros((self.natm,3)) + _sigma = np.zeros((3,3)) + _gradient_str = _gradient.ctypes.data_as(ctypes.c_void_p) + _sigma_str = _sigma.ctypes.data_as(ctypes.c_void_p) + else: + _gradient = None + _sigma = None + _gradient_str = lib.c_null_ptr() + _sigma_str = lib.c_null_ptr() + + err = libdftd3.dftd3_new_error() + libdftd3.dftd3_get_dispersion( + err, + self._mol, + self._disp, + self._param, + _energy.ctypes.data_as(ctypes.c_void_p), + _gradient_str, + _sigma_str) + res = dict(energy=_energy) + if _gradient is not None: + res.update(gradient=_gradient) + if _sigma is not None: + res.update(virial=_sigma) + libdftd3.dftd3_delete_error(err) + return res \ No newline at end of file diff --git a/gpu4pyscf/lib/dftd4.py b/gpu4pyscf/lib/dftd4.py new file mode 100644 index 00000000..ae5adc05 --- /dev/null +++ b/gpu4pyscf/lib/dftd4.py @@ -0,0 +1,91 @@ +# Copyright 2023 The GPU4PySCF Authors. All Rights Reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import os +import numpy as np +import ctypes +from pyscf import lib, gto + +libdftd4 = np.ctypeslib.load_library('libdftd4', os.path.abspath(os.path.join(__file__, '..', 'deps', 'lib'))) + +libdftd4.dftd4_new_error.restype = ctypes.c_void_p +libdftd4.dftd4_new_structure.restype = ctypes.c_void_p + +class DFTD4Dispersion(lib.StreamObject): + def __init__(self, mol, xc, atm=False): + coords = np.asarray(mol.atom_coords(), dtype=np.double, order='C') + charges = np.asarray(mol.atom_charges(), dtype=np.int32) + nuc_types = [gto.charge(mol.atom_symbol(ia)) + for ia in range(mol.natm)] + nuc_types = np.asarray(nuc_types, dtype=np.int32) + self.natm = mol.natm + self._lattice = lib.c_null_ptr() + self._periodic = lib.c_null_ptr() + + err = libdftd4.dftd4_new_error() + self._mol = libdftd4.dftd4_new_structure( + err, + ctypes.c_int(mol.natm), + nuc_types.ctypes.data_as(ctypes.c_void_p), + coords.ctypes.data_as(ctypes.c_void_p), + charges.ctypes.data_as(ctypes.c_void_p), + self._lattice, + self._periodic, + ) + + self._disp = libdftd4.dftd4_new_d4_model(err, self._mol) + self._param = libdftd4.dftd4_load_rational_damping( + err, + ctypes.create_string_buffer(xc.encode(), size=50), + ctypes.c_bool(atm)) + libdftd4.dftd4_delete_error(err) + + def __del__(self): + err = libdftd4.dftd4_new_error() + libdftd4.dftd4_delete_structure(err, self._mol) + libdftd4.dftd4_delete_param(err, self._param) + libdftd4.dftd4_delete_model(err, self._disp) + libdftd4.dftd4_delete_error(err) + + def get_dispersion(self, grad=False): + res = {} + _energy = np.array(0.0, dtype=np.double) + if grad: + _gradient = np.zeros((self.natm,3)) + _sigma = np.zeros((3,3)) + _gradient_str = _gradient.ctypes.data_as(ctypes.c_void_p) + _sigma_str = _sigma.ctypes.data_as(ctypes.c_void_p) + else: + _gradient = None + _sigma = None + _gradient_str = lib.c_null_ptr() + _sigma_str = lib.c_null_ptr() + + err = libdftd4.dftd4_new_error() + libdftd4.dftd4_get_dispersion( + err, + self._mol, + self._disp, + self._param, + _energy.ctypes.data_as(ctypes.c_void_p), + _gradient_str, + _sigma_str) + res = dict(energy=_energy) + if _gradient is not None: + res.update(gradient=_gradient) + if _sigma is not None: + res.update(virial=_sigma) + libdftd4.dftd4_delete_error(err) + return res \ No newline at end of file diff --git a/gpu4pyscf/lib/tests/test_cupy_helper.py b/gpu4pyscf/lib/tests/test_cupy_helper.py index befbf827..2b085674 100644 --- a/gpu4pyscf/lib/tests/test_cupy_helper.py +++ b/gpu4pyscf/lib/tests/test_cupy_helper.py @@ -31,7 +31,7 @@ def test_take_last2d(self): assert(cupy.linalg.norm(a[:,indices][:,:,indices] - b) < 1e-10) def test_transpose_sum(self): - n = 1287 + n = 31 count = 127 a = cupy.random.rand(count,n,n) b = a + a.transpose(0,2,1) diff --git a/gpu4pyscf/lib/tests/test_dftd3.py b/gpu4pyscf/lib/tests/test_dftd3.py new file mode 100644 index 00000000..c88b020b --- /dev/null +++ b/gpu4pyscf/lib/tests/test_dftd3.py @@ -0,0 +1,100 @@ +# Copyright 2023 The GPU4PySCF Authors. All Rights Reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest +import numpy as np +from pyscf import gto +from gpu4pyscf.lib import dftd3 + +class KnownValues(unittest.TestCase): + def test_energy_r2scan_d3(self): + mol = gto.M( + atom=''' + C -0.755422531 -0.796459123 -1.023590391 + C 0.634274834 -0.880017014 -1.075233285 + C 1.406955202 0.199695367 -0.653144334 + C 0.798863737 1.361204515 -0.180597909 + C -0.593166787 1.434312023 -0.133597923 + C -1.376239198 0.359205222 -0.553258516 + I -1.514344238 3.173268101 0.573601106 + H 1.110906949 -1.778801728 -1.440619836 + H 1.399172302 2.197767355 0.147412751 + H 2.486417780 0.142466525 -0.689380574 + H -2.454252250 0.422581120 -0.512807958 + H -1.362353593 -1.630564523 -1.348743149 + S -3.112683203 6.289227834 1.226984439 + H -4.328789697 5.797771251 0.973373089 + C -2.689135032 6.703163830 -0.489062886 + H -1.684433029 7.115457372 -0.460265708 + H -2.683867206 5.816530502 -1.115183775 + H -3.365330613 7.451201412 -0.890098894 + ''') + + dftd3_model = dftd3.DFTD3Dispersion(mol, "r2SCAN", atm=True) + res = dftd3_model.get_dispersion() + assert np.allclose(res['energy'], -0.005790963570050724) + + dftd3_model = dftd3.DFTD3Dispersion(mol, "r2SCAN") + res = dftd3_model.get_dispersion() + assert np.allclose(res['energy'], -0.005784012374055654) + + def test_gradient_r2scan_d3(self): + mol = gto.M( + atom=''' + H 0.002144194 0.361043475 0.029799709 + C 0.015020592 0.274789738 1.107648016 + C 1.227632658 0.296655040 1.794629427 + C 1.243958826 0.183702791 3.183703934 + C 0.047958213 0.048915002 3.886484583 + C -1.165135654 0.026954348 3.200213281 + C -1.181832083 0.139828643 1.810376587 + H 2.155807907 0.399177037 1.249441585 + H 2.184979344 0.198598553 3.716170761 + H 0.060934662 -0.040672756 4.964014252 + H -2.093220602 -0.078628959 3.745125056 + H -2.122845437 0.123257119 1.277645797 + Br -0.268325907 -3.194209024 1.994458950 + C 0.049999933 -5.089197474 1.929391171 + F 0.078949601 -5.512441335 0.671851563 + F 1.211983937 -5.383996300 2.498664481 + F -0.909987405 -5.743747328 2.570721738 + ''') + + ref = np.array([ + [+7.13721248e-07, +2.19571763e-05, -3.77372946e-05], + [+9.19838860e-07, +3.53459763e-05, -1.43306994e-06], + [+7.43860881e-06, +3.78237447e-05, +8.46031238e-07], + [+8.06120927e-06, +3.79834948e-05, +8.58427570e-06], + [+1.16592466e-06, +3.62585085e-05, +1.16326308e-05], + [-3.69381337e-06, +3.39047971e-05, +6.92483428e-06], + [-3.05404225e-06, +3.29484247e-05, +1.80766271e-06], + [+3.51228183e-05, +2.08136972e-05, -1.76546837e-05], + [+3.49762054e-05, +1.66544908e-05, +2.14435772e-05], + [+1.57516340e-06, +1.41373959e-05, +4.21574793e-05], + [-3.35392428e-05, +1.49030766e-05, +2.29976305e-05], + [-3.38817253e-05, +1.82002569e-05, -1.72487448e-05], + [-2.15610724e-05, -1.87935101e-04, -3.02815495e-05], + [+1.27580963e-06, -5.96841724e-05, -5.99713166e-06], + [+9.01173808e-07, -2.23010304e-05, -7.96228701e-06], + [+7.42062176e-06, -2.79631452e-05, +7.03703317e-07], + [-3.84119900e-06, -2.30475903e-05, +1.21693625e-06],] + ) + dftd3_model = dftd3.DFTD3Dispersion(mol, "r2SCAN", atm=False) + res = dftd3_model.get_dispersion(grad=True) + assert np.linalg.norm(ref - res['gradient']) < 1e-10 + +if __name__ == "__main__": + print("Full tests for DFTD3 module") + unittest.main() \ No newline at end of file diff --git a/gpu4pyscf/lib/tests/test_dftd4.py b/gpu4pyscf/lib/tests/test_dftd4.py new file mode 100644 index 00000000..a8f8737c --- /dev/null +++ b/gpu4pyscf/lib/tests/test_dftd4.py @@ -0,0 +1,99 @@ +# Copyright 2023 The GPU4PySCF Authors. All Rights Reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest +import numpy as np +from pyscf import gto +from gpu4pyscf.lib import dftd4 + +class KnownValues(unittest.TestCase): + def test_energy_r2scan_d4(self): + mol = gto.M( + atom=''' + C -0.755422531 -0.796459123 -1.023590391 + C 0.634274834 -0.880017014 -1.075233285 + C 1.406955202 0.199695367 -0.653144334 + C 0.798863737 1.361204515 -0.180597909 + C -0.593166787 1.434312023 -0.133597923 + C -1.376239198 0.359205222 -0.553258516 + I -1.514344238 3.173268101 0.573601106 + H 1.110906949 -1.778801728 -1.440619836 + H 1.399172302 2.197767355 0.147412751 + H 2.486417780 0.142466525 -0.689380574 + H -2.454252250 0.422581120 -0.512807958 + H -1.362353593 -1.630564523 -1.348743149 + S -3.112683203 6.289227834 1.226984439 + H -4.328789697 5.797771251 0.973373089 + C -2.689135032 6.703163830 -0.489062886 + H -1.684433029 7.115457372 -0.460265708 + H -2.683867206 5.816530502 -1.115183775 + H -3.365330613 7.451201412 -0.890098894 + ''') + + dftd4_model = dftd4.DFTD4Dispersion(mol, xc="r2SCAN", atm=True) + res = dftd4_model.get_dispersion() + assert np.allclose(res['energy'], -0.005001101058518388) + + dftd4_model = dftd4.DFTD4Dispersion(mol, xc="r2SCAN") + res = dftd4_model.get_dispersion() + assert np.allclose(res['energy'], -0.005001101058518388) + + def test_gradient_r2scan_d4(self): + mol = gto.M( + atom=''' + H 0.002144194 0.361043475 0.029799709 + C 0.015020592 0.274789738 1.107648016 + C 1.227632658 0.296655040 1.794629427 + C 1.243958826 0.183702791 3.183703934 + C 0.047958213 0.048915002 3.886484583 + C -1.165135654 0.026954348 3.200213281 + C -1.181832083 0.139828643 1.810376587 + H 2.155807907 0.399177037 1.249441585 + H 2.184979344 0.198598553 3.716170761 + H 0.060934662 -0.040672756 4.964014252 + H -2.093220602 -0.078628959 3.745125056 + H -2.122845437 0.123257119 1.277645797 + Br -0.268325907 -3.194209024 1.994458950 + C 0.049999933 -5.089197474 1.929391171 + F 0.078949601 -5.512441335 0.671851563 + F 1.211983937 -5.383996300 2.498664481 + F -0.909987405 -5.743747328 2.570721738 + ''') + + ref = np.array([ + [+6.02987248e-07, +1.18181692e-05, -2.11659178e-05], + [+3.77083487e-07, +4.21255367e-05, -3.65576556e-05], + [+3.71749233e-05, +4.38986750e-05, -1.64037320e-05], + [+3.79004788e-05, +4.09262181e-05, +2.57427629e-05], + [+1.49281462e-06, +3.63132380e-05, +4.66732244e-05], + [-3.45592945e-05, +3.46256250e-05, +2.53829747e-05], + [-3.48859913e-05, +3.74107269e-05, -1.56473785e-05], + [+2.00543104e-05, +1.15042699e-05, -9.90469697e-06], + [+1.99879228e-05, +9.25641402e-06, +1.21976769e-05], + [+1.10396127e-06, +7.69249859e-06, +2.38607706e-05], + [-1.86258815e-05, +7.79467748e-06, +1.29284817e-05], + [-1.87883833e-05, +9.46661745e-06, -9.65731010e-06], + [-2.38952311e-05, -1.10356928e-04, -2.28127181e-05], + [+4.05848507e-07, -5.94239995e-05, -6.36138164e-06], + [+2.78030538e-06, -3.80326610e-05, -1.91595254e-05], + [+1.91553258e-05, -4.44033682e-05, +4.86234846e-06], + [-1.02811799e-05, -4.06157099e-05, +6.02207637e-06],] + ) + dftd4_model = dftd4.DFTD4Dispersion(mol, "r2SCAN", atm=False) + res = dftd4_model.get_dispersion(grad=True) + assert np.linalg.norm(ref - res['gradient']) < 1e-10 +if __name__ == "__main__": + print("Full tests for DFTD4 module") + unittest.main() \ No newline at end of file diff --git a/setup.py b/setup.py index 3c842953..8ea88dd2 100755 --- a/setup.py +++ b/setup.py @@ -22,7 +22,6 @@ import subprocess import re import glob -import subprocess from setuptools import setup, find_packages, Extension from setuptools.command.build_py import build_py @@ -87,33 +86,8 @@ def run(self): else: self.spawn(cmd) - self.build_dftd('dftd3', 'https://github.com/dftd3/simple-dftd3/releases/download/v1.0.0/dftd3-1.0.0-sdist.tar.gz') - self.build_dftd('dftd4', 'https://github.com/dftd4/dftd4/releases/download/v3.6.0/dftd4-sdist-3.6.0.tar.gz') - super().run() - def build_dftd(self,project_name,source_url): - self.plat_name = get_platform() - self.build_base = 'build' - self.build_lib = os.path.join(self.build_base, 'lib') - self.build_temp = os.path.join(self.build_base, f'temp.{self.plat_name}') - - script_path = 'builder/build_dftdx.sh' - if not os.path.exists(script_path): - raise FileNotFoundError("Cannot find build script: {}".format(script_path)) - - subprocess.run(f"PROJECT_NAME={project_name} SOURCE_URL={source_url} sh {script_path}", shell=True, check=True) - - build_dir_pattern = f'tmp/{project_name}-*/tmp/{project_name}-build/lib/python3/dist-packages/{project_name}' - build_dirs = glob.glob(build_dir_pattern) - if not len(build_dirs) == 1: - raise FileNotFoundError("Cannot find build directory: {}".format(build_dir_pattern)) - build_dir = build_dirs[0] - - target_dir = os.path.join(self.build_lib, 'gpu4pyscf', project_name) - self.copy_tree(build_dir, target_dir) - - # build_py will produce plat_name = 'any'. Patch the bdist_wheel to change the # platform tag because the C extensions are platform dependent. from wheel.bdist_wheel import bdist_wheel @@ -152,13 +126,7 @@ def initialize_with_default_plat_name(self): install_requires=[ 'pyscf>=2.4.0', f'cupy-cuda{CUDA_VERSION}>=12.0', - # 'dftd3==0.7.0', - # 'dftd4==3.5.0', 'geometric', f'gpu4pyscf-libxc-cuda{CUDA_VERSION}', - ], - package_data={ - "gpu4pyscf.dftd3": ["_libdftd3*.so", "parameters.toml"], - "gpu4pyscf.dftd4": ["_libdftd4*.so", "*.toml", "*.json"], - }, + ] ) From 523dbf147a307c729754379dfe097fd057c14e2d Mon Sep 17 00:00:00 2001 From: Xiaojie Wu Date: Tue, 16 Jan 2024 13:38:15 -0800 Subject: [PATCH 07/10] Update build_wheels.sh --- dockerfiles/manylinux/build_wheels.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockerfiles/manylinux/build_wheels.sh b/dockerfiles/manylinux/build_wheels.sh index a4f444f9..e95df9f7 100644 --- a/dockerfiles/manylinux/build_wheels.sh +++ b/dockerfiles/manylinux/build_wheels.sh @@ -15,7 +15,7 @@ export PATH=$CUDA_HOME/bin:$PATH export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH # blas is required by DFTD3 and DFTD4 -yum install openblas-devel +yum install -y openblas-devel # Compile wheels rm -rf /gpu4pyscf/wheelhouse From 5f22de4638b7518275b0f2b66851421ca8b16cab Mon Sep 17 00:00:00 2001 From: Xiaojie Wu Date: Tue, 16 Jan 2024 16:51:07 -0800 Subject: [PATCH 08/10] bugfix:v0.6.16 (#84) * cmake workflow for dftd3 and dftd4 * updated cmake * new workflow for dftd3 and dftd4 * remove package_data * resolve dependencies in dftd3 and dftd4 * add dftd3 and dftd4 to __init__.py * updated the script for building wheels * memory leak in dftd3 and dft4 --- gpu4pyscf/__init__.py | 2 +- gpu4pyscf/lib/dftd3.py | 14 ++++++++++---- gpu4pyscf/lib/dftd4.py | 9 ++++++--- gpu4pyscf/lib/tests/test_dftd4.py | 1 + gpu4pyscf/qmmm/test/test_chelpg.py | 6 +++--- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/gpu4pyscf/__init__.py b/gpu4pyscf/__init__.py index 12273fed..11890d28 100644 --- a/gpu4pyscf/__init__.py +++ b/gpu4pyscf/__init__.py @@ -1,5 +1,5 @@ from . import lib, grad, hessian, solvent, scf, dft -__version__ = '0.6.15' +__version__ = '0.6.16' # monkey patch libxc reference due to a bug in nvcc from pyscf.dft import libxc diff --git a/gpu4pyscf/lib/dftd3.py b/gpu4pyscf/lib/dftd3.py index 6c28bee3..85e38eae 100644 --- a/gpu4pyscf/lib/dftd3.py +++ b/gpu4pyscf/lib/dftd3.py @@ -30,9 +30,14 @@ "d3op": libdftd3.dftd3_load_optimizedpower_damping #OptimizedPowerDampingParam, } - -libdftd3.dftd3_new_error.restype = ctypes.c_void_p -libdftd3.dftd3_new_structure.restype = ctypes.c_void_p +libdftd3.dftd3_new_error.restype = ctypes.c_void_p +libdftd3.dftd3_new_structure.restype = ctypes.c_void_p +libdftd3.dftd3_load_optimizedpower_damping.restype = ctypes.c_void_p +libdftd3.dftd3_load_mzero_damping.restype = ctypes.c_void_p +libdftd3.dftd3_load_mrational_damping.restype = ctypes.c_void_p +libdftd3.dftd3_load_zero_damping.restype = ctypes.c_void_p +libdftd3.dftd3_load_rational_damping.restype = ctypes.c_void_p +libdftd3.dftd3_new_d3_model.restype = ctypes.c_void_p class DFTD3Dispersion(lib.StreamObject): def __init__(self, mol, xc, version='d3bj', atm=False): @@ -63,8 +68,9 @@ def __init__(self, mol, xc, version='d3bj', atm=False): def __del__(self): err = libdftd3.dftd3_new_error() + param = ctypes.cast(self._param, ctypes.c_void_p) + libdftd3.dftd3_delete_param(ctypes.byref(param)) libdftd3.dftd3_delete_structure(err, self._mol) - libdftd3.dftd3_delete_param(err, self._param) libdftd3.dftd3_delete_model(err, self._disp) libdftd3.dftd3_delete_error(err) diff --git a/gpu4pyscf/lib/dftd4.py b/gpu4pyscf/lib/dftd4.py index ae5adc05..47defa15 100644 --- a/gpu4pyscf/lib/dftd4.py +++ b/gpu4pyscf/lib/dftd4.py @@ -20,8 +20,10 @@ libdftd4 = np.ctypeslib.load_library('libdftd4', os.path.abspath(os.path.join(__file__, '..', 'deps', 'lib'))) -libdftd4.dftd4_new_error.restype = ctypes.c_void_p -libdftd4.dftd4_new_structure.restype = ctypes.c_void_p +libdftd4.dftd4_new_error.restype = ctypes.c_void_p +libdftd4.dftd4_new_structure.restype = ctypes.c_void_p +libdftd4.dftd4_new_d4_model.restype = ctypes.c_void_p +libdftd4.dftd4_load_rational_damping.restype = ctypes.c_void_p class DFTD4Dispersion(lib.StreamObject): def __init__(self, mol, xc, atm=False): @@ -54,8 +56,9 @@ def __init__(self, mol, xc, atm=False): def __del__(self): err = libdftd4.dftd4_new_error() + param = ctypes.cast(self._param, ctypes.c_void_p) + libdftd4.dftd4_delete_param(ctypes.byref(param)) libdftd4.dftd4_delete_structure(err, self._mol) - libdftd4.dftd4_delete_param(err, self._param) libdftd4.dftd4_delete_model(err, self._disp) libdftd4.dftd4_delete_error(err) diff --git a/gpu4pyscf/lib/tests/test_dftd4.py b/gpu4pyscf/lib/tests/test_dftd4.py index a8f8737c..0e0f4f03 100644 --- a/gpu4pyscf/lib/tests/test_dftd4.py +++ b/gpu4pyscf/lib/tests/test_dftd4.py @@ -94,6 +94,7 @@ def test_gradient_r2scan_d4(self): dftd4_model = dftd4.DFTD4Dispersion(mol, "r2SCAN", atm=False) res = dftd4_model.get_dispersion(grad=True) assert np.linalg.norm(ref - res['gradient']) < 1e-10 + if __name__ == "__main__": print("Full tests for DFTD4 module") unittest.main() \ No newline at end of file diff --git a/gpu4pyscf/qmmm/test/test_chelpg.py b/gpu4pyscf/qmmm/test/test_chelpg.py index 836ce15e..5154e35b 100644 --- a/gpu4pyscf/qmmm/test/test_chelpg.py +++ b/gpu4pyscf/qmmm/test/test_chelpg.py @@ -49,7 +49,7 @@ def run_dft_chelpg(xc, deltaR): e_dft = mf.kernel() q = chelpg.eval_chelpg_layer_gpu(mf, deltaR=deltaR) return e_dft, q - + class KnownValues(unittest.TestCase): ''' @@ -63,7 +63,7 @@ class KnownValues(unittest.TestCase): CHELPG TRUE SCF_CONVERGENCE 10 $end - + Ground-State ChElPG Net Atomic Charges Atom Charge (a.u.) @@ -78,7 +78,7 @@ def test_rks_b3lyp(self): e_tot, q = run_dft_chelpg('B3LYP', 0.1) assert np.allclose(e_tot, -76.4666495181) assert np.allclose(q, np.array([-0.712558, 0.356292, 0.356266])) - + if __name__ == "__main__": print("Full Tests for SCF") From a63caf09f185ed18678ce8a8a5fa6a8d72fd1b7d Mon Sep 17 00:00:00 2001 From: Xiaojie Wu Date: Thu, 18 Jan 2024 00:13:34 -0800 Subject: [PATCH 09/10] V0.6.17 (#85) * cmake workflow for dftd3 and dftd4 * updated cmake * new workflow for dftd3 and dftd4 * remove package_data * resolve dependencies in dftd3 and dftd4 * add dftd3 and dftd4 to __init__.py * updated the script for building wheels * memory leak in dftd3 and dft4 * fixed memory leak in dftd3, debug for A800 * remove comments --- gpu4pyscf/__config__.py | 2 +- gpu4pyscf/__init__.py | 3 ++- gpu4pyscf/df/df.py | 7 ++++--- gpu4pyscf/df/grad/rhf.py | 2 +- gpu4pyscf/lib/cupy_helper.py | 6 ++++-- gpu4pyscf/lib/cutensor.py | 2 +- gpu4pyscf/lib/dftd3.py | 37 +++++++++++++++++++++--------------- gpu4pyscf/lib/dftd4.py | 29 +++++++++++++++++----------- gpu4pyscf/qmmm/chelpg.py | 5 +---- setup.py | 2 +- 10 files changed, 55 insertions(+), 40 deletions(-) diff --git a/gpu4pyscf/__config__.py b/gpu4pyscf/__config__.py index 93346e36..1bd1312d 100644 --- a/gpu4pyscf/__config__.py +++ b/gpu4pyscf/__config__.py @@ -4,7 +4,7 @@ GB = 1024*1024*1024 # such as A100-80G if props['totalGlobalMem'] >= 64 * GB: - min_ao_blksize = 256 + min_ao_blksize = 128 min_grid_blksize = 128*128 ao_aligned = 32 grid_aligned = 128 diff --git a/gpu4pyscf/__init__.py b/gpu4pyscf/__init__.py index 11890d28..ea5e78a2 100644 --- a/gpu4pyscf/__init__.py +++ b/gpu4pyscf/__init__.py @@ -1,5 +1,6 @@ from . import lib, grad, hessian, solvent, scf, dft -__version__ = '0.6.16' + +__version__ = '0.6.17' # monkey patch libxc reference due to a bug in nvcc from pyscf.dft import libxc diff --git a/gpu4pyscf/df/df.py b/gpu4pyscf/df/df.py index 14294843..f7e3217e 100644 --- a/gpu4pyscf/df/df.py +++ b/gpu4pyscf/df/df.py @@ -223,20 +223,21 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False): nj = j1 - j0 if sr_only: # TODO: in-place implementation or short-range kernel - ints_slices = cupy.empty([naoaux, nj, ni], order='C') + ints_slices = cupy.zeros([naoaux, nj, ni], order='C') for cp_kl_id, _ in enumerate(intopt.aux_log_qs): k0 = intopt.sph_aux_loc[cp_kl_id] k1 = intopt.sph_aux_loc[cp_kl_id+1] int3c2e.get_int3c2e_slice(intopt, cp_ij_id, cp_kl_id, out=ints_slices[k0:k1]) if omega is not None: - ints_slices_lr = cupy.empty([naoaux, nj, ni], order='C') + ints_slices_lr = cupy.zeros([naoaux, nj, ni], order='C') for cp_kl_id, _ in enumerate(intopt.aux_log_qs): k0 = intopt.sph_aux_loc[cp_kl_id] k1 = intopt.sph_aux_loc[cp_kl_id+1] int3c2e.get_int3c2e_slice(intopt, cp_ij_id, cp_kl_id, out=ints_slices[k0:k1], omega=omega) ints_slices -= ints_slices_lr else: - ints_slices = cupy.empty([naoaux, nj, ni], order='C') + # Initialization is required due to cutensor operations later + ints_slices = cupy.zeros([naoaux, nj, ni], order='C') for cp_kl_id, _ in enumerate(intopt.aux_log_qs): k0 = intopt.sph_aux_loc[cp_kl_id] k1 = intopt.sph_aux_loc[cp_kl_id+1] diff --git a/gpu4pyscf/df/grad/rhf.py b/gpu4pyscf/df/grad/rhf.py index ce29c904..f7b531ec 100644 --- a/gpu4pyscf/df/grad/rhf.py +++ b/gpu4pyscf/df/grad/rhf.py @@ -64,7 +64,7 @@ def get_jk(mf_grad, mol=None, dm0=None, hermi=0, with_j=True, with_k=True, omega mo_occ = cupy.asarray(mf_grad.base.mo_occ) sph_ao_idx = intopt.sph_ao_idx dm = take_last2d(dm0, sph_ao_idx) - orbo = contract('pi,i->pi', mo_coeff[:,mo_occ>0], numpy.sqrt(mo_occ[mo_occ>0])) + orbo = mo_coeff[:,mo_occ>0] * mo_occ[mo_occ>0] ** 0.5 orbo = orbo[sph_ao_idx, :] nocc = orbo.shape[-1] diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py index 6a82e42d..42eea65c 100644 --- a/gpu4pyscf/lib/cupy_helper.py +++ b/gpu4pyscf/lib/cupy_helper.py @@ -376,8 +376,10 @@ def cart2sph(t, axis=0, ang=1, out=None): t_cart = t.reshape([i0*nli, li_size[0], i3]) if(out is not None): out = out.reshape([i0*nli, li_size[1], i3]) - t_sph = contract('min,ip->mpn', t_cart, c2s, out=out) - return t_sph.reshape(out_shape) + out[:] = cupy.einsum('min,ip->mpn', t_cart, c2s) + else: + out = cupy.einsum('min,ip->mpn', t_cart, c2s) + return out.reshape(out_shape) # a copy with modification from # https://github.com/pyscf/pyscf/blob/9219058ac0a1bcdd8058166cad0fb9127b82e9bf/pyscf/lib/linalg_helper.py#L1536 diff --git a/gpu4pyscf/lib/cutensor.py b/gpu4pyscf/lib/cutensor.py index 99dd194f..a54c1ffe 100644 --- a/gpu4pyscf/lib/cutensor.py +++ b/gpu4pyscf/lib/cutensor.py @@ -34,7 +34,7 @@ except ImportError: cutensor = None CUTENSOR_ALGO_DEFAULT = None - + def _create_mode_with_cache(mode): integer_mode = [] for x in mode: diff --git a/gpu4pyscf/lib/dftd3.py b/gpu4pyscf/lib/dftd3.py index 85e38eae..f92798eb 100644 --- a/gpu4pyscf/lib/dftd3.py +++ b/gpu4pyscf/lib/dftd3.py @@ -30,14 +30,19 @@ "d3op": libdftd3.dftd3_load_optimizedpower_damping #OptimizedPowerDampingParam, } -libdftd3.dftd3_new_error.restype = ctypes.c_void_p -libdftd3.dftd3_new_structure.restype = ctypes.c_void_p -libdftd3.dftd3_load_optimizedpower_damping.restype = ctypes.c_void_p -libdftd3.dftd3_load_mzero_damping.restype = ctypes.c_void_p -libdftd3.dftd3_load_mrational_damping.restype = ctypes.c_void_p -libdftd3.dftd3_load_zero_damping.restype = ctypes.c_void_p -libdftd3.dftd3_load_rational_damping.restype = ctypes.c_void_p -libdftd3.dftd3_new_d3_model.restype = ctypes.c_void_p +class _d3_restype(ctypes.Structure): + pass + +_d3_p = ctypes.POINTER(_d3_restype) + +libdftd3.dftd3_new_error.restype = _d3_p +libdftd3.dftd3_new_structure.restype = _d3_p +libdftd3.dftd3_load_optimizedpower_damping.restype = _d3_p +libdftd3.dftd3_load_mzero_damping.restype = _d3_p +libdftd3.dftd3_load_mrational_damping.restype = _d3_p +libdftd3.dftd3_load_zero_damping.restype = _d3_p +libdftd3.dftd3_load_rational_damping.restype = _d3_p +libdftd3.dftd3_new_d3_model.restype = _d3_p class DFTD3Dispersion(lib.StreamObject): def __init__(self, mol, xc, version='d3bj', atm=False): @@ -64,15 +69,15 @@ def __init__(self, mol, xc, version='d3bj', atm=False): err, ctypes.create_string_buffer(xc.encode(), size=50), ctypes.c_bool(atm)) - libdftd3.dftd3_delete_error(err) + + libdftd3.dftd3_delete_error(ctypes.byref(err)) def __del__(self): err = libdftd3.dftd3_new_error() - param = ctypes.cast(self._param, ctypes.c_void_p) - libdftd3.dftd3_delete_param(ctypes.byref(param)) - libdftd3.dftd3_delete_structure(err, self._mol) - libdftd3.dftd3_delete_model(err, self._disp) - libdftd3.dftd3_delete_error(err) + libdftd3.dftd3_delete_param(ctypes.byref(self._param)) + libdftd3.dftd3_delete_structure(err, ctypes.byref(self._mol)) + libdftd3.dftd3_delete_model(err, ctypes.byref(self._disp)) + libdftd3.dftd3_delete_error(ctypes.byref(err)) def get_dispersion(self, grad=False): res = {} @@ -102,5 +107,7 @@ def get_dispersion(self, grad=False): res.update(gradient=_gradient) if _sigma is not None: res.update(virial=_sigma) - libdftd3.dftd3_delete_error(err) + + libdftd3.dftd3_delete_error(ctypes.byref(err)) + return res \ No newline at end of file diff --git a/gpu4pyscf/lib/dftd4.py b/gpu4pyscf/lib/dftd4.py index 47defa15..f361c604 100644 --- a/gpu4pyscf/lib/dftd4.py +++ b/gpu4pyscf/lib/dftd4.py @@ -20,10 +20,15 @@ libdftd4 = np.ctypeslib.load_library('libdftd4', os.path.abspath(os.path.join(__file__, '..', 'deps', 'lib'))) -libdftd4.dftd4_new_error.restype = ctypes.c_void_p -libdftd4.dftd4_new_structure.restype = ctypes.c_void_p -libdftd4.dftd4_new_d4_model.restype = ctypes.c_void_p -libdftd4.dftd4_load_rational_damping.restype = ctypes.c_void_p +class _d4_restype(ctypes.Structure): + pass + +_d4_p = ctypes.POINTER(_d4_restype) + +libdftd4.dftd4_new_error.restype = _d4_p +libdftd4.dftd4_new_structure.restype = _d4_p +libdftd4.dftd4_new_d4_model.restype = _d4_p +libdftd4.dftd4_load_rational_damping.restype = _d4_p class DFTD4Dispersion(lib.StreamObject): def __init__(self, mol, xc, atm=False): @@ -52,15 +57,15 @@ def __init__(self, mol, xc, atm=False): err, ctypes.create_string_buffer(xc.encode(), size=50), ctypes.c_bool(atm)) - libdftd4.dftd4_delete_error(err) + + libdftd4.dftd4_delete_error(ctypes.byref(err)) def __del__(self): err = libdftd4.dftd4_new_error() - param = ctypes.cast(self._param, ctypes.c_void_p) - libdftd4.dftd4_delete_param(ctypes.byref(param)) - libdftd4.dftd4_delete_structure(err, self._mol) - libdftd4.dftd4_delete_model(err, self._disp) - libdftd4.dftd4_delete_error(err) + libdftd4.dftd4_delete_param(ctypes.byref(self._param)) + libdftd4.dftd4_delete_structure(err, ctypes.byref(self._mol)) + libdftd4.dftd4_delete_model(err, ctypes.byref(self._disp)) + libdftd4.dftd4_delete_error(ctypes.byref(err)) def get_dispersion(self, grad=False): res = {} @@ -90,5 +95,7 @@ def get_dispersion(self, grad=False): res.update(gradient=_gradient) if _sigma is not None: res.update(virial=_sigma) - libdftd4.dftd4_delete_error(err) + + libdftd4.dftd4_delete_error(ctypes.byref(err)) + return res \ No newline at end of file diff --git a/gpu4pyscf/qmmm/chelpg.py b/gpu4pyscf/qmmm/chelpg.py index 3851acc0..1c9de822 100644 --- a/gpu4pyscf/qmmm/chelpg.py +++ b/gpu4pyscf/qmmm/chelpg.py @@ -230,10 +230,7 @@ def build(self, cutoff=1e-14, group_size=None, ncptype = len(log_qs) self.bpcache = ctypes.POINTER(BasisProdCache)() - if diag_block_with_triu: - scale_shellpair_diag = 1. - else: - scale_shellpair_diag = 0.5 + scale_shellpair_diag = 1. libgint.GINTinit_basis_prod( ctypes.byref(self.bpcache), ctypes.c_double(scale_shellpair_diag), ao_loc.ctypes.data_as(ctypes.c_void_p), diff --git a/setup.py b/setup.py index 8ea88dd2..86cfe5d9 100755 --- a/setup.py +++ b/setup.py @@ -125,7 +125,7 @@ def initialize_with_default_plat_name(self): cmdclass={'build_py': CMakeBuildPy}, install_requires=[ 'pyscf>=2.4.0', - f'cupy-cuda{CUDA_VERSION}>=12.0', + f'cupy-cuda{CUDA_VERSION}>=12.3', 'geometric', f'gpu4pyscf-libxc-cuda{CUDA_VERSION}', ] From b50110783cfef3a1e123a2869920eb3647f7aa72 Mon Sep 17 00:00:00 2001 From: puzhichen <147788878+puzhichen@users.noreply.github.com> Date: Sat, 20 Jan 2024 08:43:47 +0800 Subject: [PATCH 10/10] Uks dev (#78) * Add the UHF module. * Add the UHF and unit test. * Reuse the _kernel in hf.py, and correct the unit test for UHF. * Add the UKS and a unit test for it. * Fix typos in test_uks.py --- gpu4pyscf/dft/tests/test_uks.py | 95 ++++++++++++ gpu4pyscf/dft/uks.py | 117 ++++++++++++++- gpu4pyscf/scf/diis.py | 14 +- gpu4pyscf/scf/tests/test_uhf.py | 255 ++++++++++++++++++++++++++++++++ gpu4pyscf/scf/uhf.py | 152 ++++++++++++++++++- 5 files changed, 621 insertions(+), 12 deletions(-) create mode 100644 gpu4pyscf/dft/tests/test_uks.py create mode 100644 gpu4pyscf/scf/tests/test_uhf.py diff --git a/gpu4pyscf/dft/tests/test_uks.py b/gpu4pyscf/dft/tests/test_uks.py new file mode 100644 index 00000000..afe51281 --- /dev/null +++ b/gpu4pyscf/dft/tests/test_uks.py @@ -0,0 +1,95 @@ +# gpu4pyscf is a plugin to use Nvidia GPU in PySCF package +# +# Copyright (C) 2022 Qiming Sun +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import numpy as np +import unittest +import pyscf +from gpu4pyscf.dft import uks + +atom = ''' +O 0.0000000000 -0.0000000000 0.1174000000 +H -0.7570000000 -0.0000000000 -0.4696000000 +H 0.7570000000 0.0000000000 -0.4696000000 +''' +bas='def2-qzvpp' +grids_level = 3 +nlcgrids_level = 1 + +def setUpModule(): + global mol + mol = pyscf.M( + atom=atom, + basis=bas, + max_memory=32000, + verbose = 1, + spin = 1, + charge = 1, + output = '/dev/null' + ) + +def tearDownModule(): + global mol + mol.stdout.close() + del mol + +def run_dft(xc): + mf = uks.UKS(mol, xc=xc) + mf.grids.level = grids_level + mf.nlcgrids.level = nlcgrids_level + e_dft = mf.kernel() + return e_dft + +class KnownValues(unittest.TestCase): + ''' + known values are obtained by pyscf + ''' + def test_uks_lda(self): + print('------- LDA ----------------') + e_tot = run_dft("LDA, vwn5") + assert np.allclose(e_tot, -75.42821982483972) + + def test_uks_pbe(self): + print('------- PBE ----------------') + e_tot = run_dft('PBE') + assert np.allclose(e_tot, -75.91732813416843) + + def test_uks_b3lyp(self): + print('-------- B3LYP -------------') + e_tot = run_dft('B3LYP') + assert np.allclose(e_tot, -76.00306439862237) + + def test_uks_m06(self): + print('--------- M06 --------------') + e_tot = run_dft("M06") + assert np.allclose(e_tot, -75.96551006522827) + + def test_uks_wb97(self): + print('-------- wB97 --------------') + e_tot = run_dft("HYB_GGA_XC_WB97") + assert np.allclose(e_tot, -75.987601337562) + + def test_uks_vv10(self): + print("------- wB97m-v -------------") + e_tot = run_dft('HYB_MGGA_XC_WB97M_V') + assert np.allclose(e_tot, -75.97363094678428) + + #TODO: add test cases for D3/D4 and gradient + +if __name__ == "__main__": + print("Full Tests for dft") + unittest.main() + diff --git a/gpu4pyscf/dft/uks.py b/gpu4pyscf/dft/uks.py index a3e48cb4..d56f067f 100644 --- a/gpu4pyscf/dft/uks.py +++ b/gpu4pyscf/dft/uks.py @@ -15,16 +15,121 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import cupy from pyscf.dft import uks -from gpu4pyscf.dft import numint -from gpu4pyscf.scf.uhf import UHF +from pyscf import lib +from gpu4pyscf import scf +from gpu4pyscf.lib import logger +from gpu4pyscf.dft import numint, gen_grid, rks +from gpu4pyscf.lib.cupy_helper import tag_array -class UKS(uks.UKS): + +def get_veff(ks, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1): + '''Coulomb + XC functional for UKS. See pyscf/dft/rks.py + :func:`get_veff` fore more details. + ''' + if mol is None: mol = ks.mol + if dm is None: dm = ks.make_rdm1() + t0 = logger.init_timer(ks) + rks.initialize_grids(ks, mol, dm) + + if hasattr(ks, 'screen_tol') and ks.screen_tol is not None: + ks.direct_scf_tol = ks.screen_tol + ground_state = (isinstance(dm, cupy.ndarray) and dm.ndim == 3) + + ni = ks._numint + if hermi == 2: # because rho = 0 + n, exc, vxc = (0,0), 0, 0 + else: + max_memory = ks.max_memory - lib.current_memory()[0] + n, exc, vxc = ni.nr_uks(mol, ks.grids, ks.xc, dm, max_memory=max_memory) + logger.debug(ks, 'nelec by numeric integration = %s', n) + if ks.nlc or ni.libxc.is_nlc(ks.xc): + if ni.libxc.is_nlc(ks.xc): + xc = ks.xc + else: + assert ni.libxc.is_nlc(ks.nlc) + xc = ks.nlc + n, enlc, vnlc = ni.nr_nlc_vxc(mol, ks.nlcgrids, xc, dm[0]+dm[1], + max_memory=max_memory) + exc += enlc + vxc += vnlc + logger.debug(ks, 'nelec with nlc grids = %s', n) + t0 = logger.timer(ks, 'vxc', *t0) + + if not ni.libxc.is_hybrid_xc(ks.xc): + vk = None + if (ks._eri is None and ks.direct_scf and + getattr(vhf_last, 'vj', None) is not None): + ddm = cupy.asarray(dm) - cupy.asarray(dm_last) + vj = ks.get_j(mol, ddm[0]+ddm[1], hermi) + vj += vhf_last.vj + else: + vj = ks.get_j(mol, dm[0]+dm[1], hermi) + vxc += vj + else: + omega, alpha, hyb = ni.rsh_and_hybrid_coeff(ks.xc, spin=mol.spin) + if (ks._eri is None and ks.direct_scf and + getattr(vhf_last, 'vk', None) is not None): + ddm = cupy.asarray(dm) - cupy.asarray(dm_last) + vj, vk = ks.get_jk(mol, ddm, hermi) + vk *= hyb + if abs(omega) > 1e-10: # For range separated Coulomb operator + vklr = ks.get_k(mol, ddm, hermi, omega) + vklr *= (alpha - hyb) + vk += vklr + vj = vj[0] + vj[1] + vhf_last.vj + vk += vhf_last.vk + else: + vj, vk = ks.get_jk(mol, dm, hermi) + vj = vj[0] + vj[1] + vk *= hyb + if abs(omega) > 1e-10: + vklr = ks.get_k(mol, dm, hermi, omega=omega) + vklr *= (alpha - hyb) + vk += vklr + vxc += vj - vk + + if ground_state: + exc -=(cupy.einsum('ij,ji', dm[0], vk[0]).real + + cupy.einsum('ij,ji', dm[1], vk[1]).real) * .5 + if ground_state: + ecoul = cupy.einsum('ij,ji', dm[0]+dm[1], vj).real * .5 + else: + ecoul = None + t0 = logger.timer_debug1(ks, 'jk total', *t0) + vxc = tag_array(vxc, ecoul=ecoul, exc=exc, vj=vj, vk=vk) + return vxc + + +def energy_elec(ks, dm=None, h1e=None, vhf=None): + if dm is None: dm = ks.make_rdm1() + if h1e is None: h1e = ks.get_hcore() + if vhf is None or getattr(vhf, 'ecoul', None) is None: + vhf = ks.get_veff(ks.mol, dm) + if not (isinstance(dm, cupy.ndarray) and dm.ndim == 2): + dm = dm[0] + dm[1] + return rks.energy_elec(ks, dm, h1e, vhf) + + +class UKS(scf.uhf.UHF, uks.UKS): from gpu4pyscf.lib.utils import to_cpu, to_gpu, device + _keys = {'disp', 'screen_tol'} - def __init__(self, mol, xc='LDA,VWN'): + def __init__(self, mol, xc='LDA,VWN', disp=None): super().__init__(mol, xc) + self.disp = disp self._numint = numint.NumInt() + self.screen_tol = 1e-14 + + grids_level = self.grids.level + self.grids = gen_grid.Grids(mol) + self.grids.level = grids_level - get_jk = UHF.get_jk - _eigh = UHF._eigh + nlcgrids_level = self.nlcgrids.level + self.nlcgrids = gen_grid.Grids(mol) + self.nlcgrids.level = nlcgrids_level + + energy_elec = energy_elec + get_veff = get_veff + diff --git a/gpu4pyscf/scf/diis.py b/gpu4pyscf/scf/diis.py index b1e4c1f7..c01d4972 100644 --- a/gpu4pyscf/scf/diis.py +++ b/gpu4pyscf/scf/diis.py @@ -15,7 +15,7 @@ # # Author: Qiming Sun # -# modified by Xiaojie Wu +# modified by Xiaojie Wu ; Zhichen Pu """ DIIS @@ -64,9 +64,13 @@ def get_num_vec(self): def get_err_vec(s, d, f): '''error vector = SDF - FDS''' - if isinstance(f, cupy.ndarray): + if isinstance(f, cupy.ndarray) and f.ndim == 2: sdf = reduce(cupy.dot, (s,d,f)) - errvec = (sdf.T.conj() - sdf) - return errvec + errvec = (sdf.conj().T - sdf).ravel() + elif f.ndim == s.ndim+1 and f.shape[0] == 2: # for UHF + errvec = cupy.hstack([ + get_err_vec(s, d[0], f[0]).ravel(), + get_err_vec(s, d[1], f[1]).ravel()]) else: - cpu_diis.get_err_vec(s, d, f) + raise RuntimeError('Unknown SCF DIIS type') + return errvec diff --git a/gpu4pyscf/scf/tests/test_uhf.py b/gpu4pyscf/scf/tests/test_uhf.py new file mode 100644 index 00000000..0fe8eb7b --- /dev/null +++ b/gpu4pyscf/scf/tests/test_uhf.py @@ -0,0 +1,255 @@ +# gpu4pyscf is a plugin to use Nvidia GPU in PySCF package +# +# Copyright (C) 2022 Qiming Sun +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest +import numpy as np +import cupy +import pyscf +from pyscf import lib +from gpu4pyscf import scf + +mol = pyscf.M( + atom=''' +C -0.65830719, 0.61123287, -0.00800148 +C 0.73685281, 0.61123287, -0.00800148 +C 1.43439081, 1.81898387, -0.00800148 +C 0.73673681, 3.02749287, -0.00920048 +''', + basis='ccpvtz', + charge=1, + spin=1, + output = '/dev/null' +) + +mol1 = pyscf.M( + atom=''' +C -1.20806619, -0.34108413, -0.00755148 +C 1.28636081, -0.34128013, -0.00668648 +H 2.53407081, 1.81906387, -0.00736748 +H 1.28693681, 3.97963587, -0.00925948 +''', + basis='''unc +#BASIS SET: +H S + 1.815041 1 + 0.591063 1 +H P + 2.305000 1 +#BASIS SET: +C S + 8.383976 1 + 3.577015 1 + 1.547118 1 +H P + 2.305000 1 + 1.098827 1 + 0.806750 1 + 0.282362 1 +H D + 1.81900 1 + 0.72760 1 + 0.29104 1 +H F + 0.970109 1 +C G + 0.625000 1 +C H + 0.4 1 + ''', + output = '/dev/null' +) + +def tearDownModule(): + global mol, mol1 + mol.stdout.close() + mol1.stdout.close() + del mol, mol1 + +class KnownValues(unittest.TestCase): + def test_get_jk(self): + np.random.seed(1) + nao = mol.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol) + vj, vk = mf.get_jk(mol, dm) + self.assertAlmostEqual(lib.fp(vj), -1782.4478082102428, 7) + self.assertAlmostEqual(lib.fp(vk), -280.36548013781095, 7) + mf1 = mf.to_cpu() + refj, refk = mf1.get_jk(mol, dm) + self.assertAlmostEqual(abs(vj - refj).max(), 0, 7) + self.assertAlmostEqual(abs(vk - refk).max(), 0, 7) + with lib.temporary_env(mol, cart=True): + np.random.seed(1) + nao = mol.nao + dm = np.random.random((2, nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol) + vj, vk = mf.get_jk(mol, dm) + self.assertAlmostEqual(lib.fp(vj), -1790.0063863999496, 7) + self.assertAlmostEqual(lib.fp(vk), -8.969890703683895 , 7) + + mf1 = mf.to_cpu() + refj, refk = mf1.get_jk(mol, dm) + self.assertAlmostEqual(abs(vj - refj).max(), 0, 7) + self.assertAlmostEqual(abs(vk - refk).max(), 0, 7) + + def test_get_j(self): + np.random.seed(1) + nao = mol.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol) + vj = mf.get_j(mol, dm) + self.assertAlmostEqual(lib.fp(vj), -1782.4478082102423 , 7) + + mf1 = mf.to_cpu() + refj = mf1.get_j(mol, dm) + self.assertAlmostEqual(abs(vj - refj).max(), 0, 7) + + with lib.temporary_env(mol, cart=True): + np.random.seed(1) + nao = mol.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol) + vj = mf.get_j(mol, dm) + self.assertAlmostEqual(lib.fp(vj), -1790.0063863999503, 7) + + mf1 = mf.to_cpu() + refj = mf1.get_j(mol, dm) + self.assertAlmostEqual(abs(vj - refj).max(), 0, 7) + + def test_get_k(self): + np.random.seed(1) + nao = mol.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol) + vk = mf.get_k(mol, dm) + self.assertAlmostEqual(lib.fp(vk), -280.36548013781083, 7) + + mf1 = mf.to_cpu() + refk = mf1.get_k(mol, dm) + self.assertAlmostEqual(abs(vk - refk).max(), 0, 7) + + with lib.temporary_env(mol, cart=True): + np.random.seed(1) + nao = mol.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol) + vk = mf.get_k(mol, dm) + self.assertAlmostEqual(lib.fp(vk), -8.969890703691519 , 7) + + mf1 = mf.to_cpu() + refk = mf1.get_k(mol, dm) + self.assertAlmostEqual(abs(vk - refk).max(), 0, 7) + + def test_get_jk1(self): + # test l >= 4 + np.random.seed(1) + nao = mol1.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol1) + vj, vk = mf.get_jk(mol1, dm, hermi=1) + self.assertAlmostEqual(lib.fp(vj), 179.14526555374763, 7) + self.assertAlmostEqual(lib.fp(vk), -34.851182918653606, 7) + + mf1 = mf.to_cpu() + refj, refk = mf1.get_jk(mol1, dm, hermi=1) + self.assertAlmostEqual(abs(vj - refj).max(), 0, 8) + self.assertAlmostEqual(abs(vk - refk).max(), 0, 8) + + @unittest.skip('hermi=0') + def test_get_jk1_hermi0(self): + np.random.seed(1) + nao = mol1.nao + dm = np.random.random((2,nao,nao)) + mf = scf.UHF(mol1) + vj, vk = mf.get_jk(mol1, cupy.asarray(dm), hermi=0) + self.assertAlmostEqual(lib.fp(vj.get()), 89.57263277687345 , 7) + self.assertAlmostEqual(lib.fp(vk.get()),-26.369697697245883, 7) + + mf1 = mf.to_cpu() + refj, refk = mf1.get_jk(mol1, dm, hermi=0) + self.assertAlmostEqual(abs(vj.get() - refj).max(), 0, 8) + self.assertAlmostEqual(abs(vk.get() - refk).max(), 0, 8) + + def test_get_j1(self): + # test l >= 4 + np.random.seed(1) + nao = mol1.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol1) + vj = mf.get_j(mol1, dm, hermi=1) + self.assertAlmostEqual(lib.fp(vj), 179.14526555374712, 7) + + mf1 = mf.to_cpu() + refj = mf1.get_j(mol1, dm, hermi=1) + self.assertAlmostEqual(abs(vj - refj).max(), 0, 7) + + @unittest.skip('hermi=0') + def test_get_j1_hermi0(self): + np.random.seed(1) + nao = mol1.nao + dm = np.random.random((2,nao,nao)) + mf = scf.UHF(mol1) + vj = mf.get_j(mol1, dm, hermi=0).get() + self.assertAlmostEqual(lib.fp(vj), 89.5726327768736, 7) + + mf1 = mf.to_cpu() + refj = mf1.get_j(mol1, dm, hermi=0) + self.assertAlmostEqual(abs(vj - refj).max(), 0, 7) + + def test_get_k1(self): + # test l >= 4 + np.random.seed(1) + nao = mol1.nao + dm = np.random.random((2,nao,nao)) + dm = dm + dm.transpose(0,2,1) + mf = scf.UHF(mol1) + vk = mf.get_k(mol1, dm, hermi=1) + self.assertAlmostEqual(lib.fp(vk), -34.85118291865315, 7) + + mf1 = mf.to_cpu() + refk = mf1.get_k(mol1, dm, hermi=1) + self.assertAlmostEqual(abs(vk - refk).max(), 0, 7) + + @unittest.skip('hermi=0') + def test_get_k1_hermi0(self): + np.random.seed(1) + nao = mol1.nao + dm = np.random.random((2,nao,nao)) + mf = scf.UHF(mol1) + vk = mf.get_k(mol1, dm, hermi=0).get() + self.assertAlmostEqual(lib.fp(vk),-26.369697697246007, 7) + + mf1 = mf.to_cpu() + refk = mf1.get_k(mol1, dm, hermi=0) + self.assertAlmostEqual(abs(vk - refk).max(), 0, 7) + + # end to end test + def test_uhf_scf(self): + e_tot = scf.UHF(mol).kernel() + self.assertAlmostEqual(e_tot, -150.76441654065087) + +if __name__ == "__main__": + print("Full Tests for UHF") + unittest.main() diff --git a/gpu4pyscf/scf/uhf.py b/gpu4pyscf/scf/uhf.py index 9566cf0b..d606d7dd 100644 --- a/gpu4pyscf/scf/uhf.py +++ b/gpu4pyscf/scf/uhf.py @@ -15,11 +15,161 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . + +from functools import reduce from pyscf.scf import uhf -from gpu4pyscf.scf.hf import _get_jk, eigh +from gpu4pyscf.scf.hf import _get_jk, eigh, damping, level_shift, _kernel +from gpu4pyscf.lib import logger +from gpu4pyscf.lib.cupy_helper import tag_array +import numpy as np +import cupy +from gpu4pyscf import lib +from gpu4pyscf.scf import diis +from pyscf import lib as pyscf_lib + + +def make_rdm1(mo_coeff, mo_occ, **kwargs): + '''One-particle density matrix in AO representation + + Args: + mo_coeff : tuple of 2D ndarrays + Orbital coefficients for alpha and beta spins. Each column is one orbital. + mo_occ : tuple of 1D ndarrays + Occupancies for alpha and beta spins. + Returns: + A list of 2D ndarrays for alpha and beta spins + ''' + mo_a = mo_coeff[0] + mo_b = mo_coeff[1] + dm_a = cupy.dot(mo_a*mo_occ[0], mo_a.conj().T) + dm_b = cupy.dot(mo_b*mo_occ[1], mo_b.conj().T) +# DO NOT make tag_array for DM here because the DM arrays may be modified and +# passed to functions like get_jk, get_vxc. These functions may take the tags +# (mo_coeff, mo_occ) to compute the potential if tags were found in the DM +# arrays and modifications to DM arrays may be ignored. + return tag_array((dm_a, dm_b), mo_coeff=mo_coeff, mo_occ=mo_occ) + + +def spin_square(mo, s=1): + r'''Spin square and multiplicity of UHF determinant + + Detailed derivataion please refers to the cpu pyscf. + + ''' + mo_a, mo_b = mo + nocc_a = mo_a.shape[1] + nocc_b = mo_b.shape[1] + s = reduce(cupy.dot, (mo_a.conj().T, cupy.asarray(s), mo_b)) + ssxy = (nocc_a+nocc_b) * .5 - cupy.einsum('ij,ij->', s.conj(), s) + ssz = (nocc_b-nocc_a)**2 * .25 + ss = (ssxy + ssz).real + s = cupy.sqrt(ss+.25) - .5 + return ss, s*2+1 + + +def get_fock(mf, h1e=None, s1e=None, vhf=None, dm=None, cycle=-1, diis=None, + diis_start_cycle=None, level_shift_factor=None, damp_factor=None): + if h1e is None: h1e = cupy.asarray(mf.get_hcore()) + if vhf is None: vhf = mf.get_veff(mf.mol, dm) + f = h1e + vhf + if f.ndim == 2: + f = (f, f) + if cycle < 0 and diis is None: # Not inside the SCF iteration + return f + + if diis_start_cycle is None: + diis_start_cycle = mf.diis_start_cycle + if level_shift_factor is None: + level_shift_factor = mf.level_shift + if damp_factor is None: + damp_factor = mf.damp + if s1e is None: s1e = mf.get_ovlp() + if dm is None: dm = mf.make_rdm1() + + if isinstance(level_shift_factor, (tuple, list, np.ndarray)): + shifta, shiftb = level_shift_factor + else: + shifta = shiftb = level_shift_factor + if isinstance(damp_factor, (tuple, list, np.ndarray)): + dampa, dampb = damp_factor + else: + dampa = dampb = damp_factor + + if 0 <= cycle < diis_start_cycle-1 and abs(dampa)+abs(dampb) > 1e-4: + f = (damping(s1e, dm[0], f[0], dampa), + damping(s1e, dm[1], f[1], dampb)) + if diis and cycle >= diis_start_cycle: + f = diis.update(s1e, dm, f, mf, h1e, vhf) + if abs(shifta)+abs(shiftb) > 1e-4: + f = (level_shift(s1e, dm[0], f[0], shifta), + level_shift(s1e, dm[1], f[1], shiftb)) + return f + class UHF(uhf.UHF): from gpu4pyscf.lib.utils import to_cpu, to_gpu, device + DIIS = diis.SCF_DIIS get_jk = _get_jk _eigh = staticmethod(eigh) + get_fock = get_fock + + def make_rdm1(self, mo_coeff=None, mo_occ=None, **kwargs): + if mo_coeff is None: + mo_coeff = self.mo_coeff + if mo_occ is None: + mo_occ = self.mo_occ + return make_rdm1(mo_coeff, mo_occ, **kwargs) + + def eig(self, fock, s): + e_a, c_a = self._eigh(fock[0], s) + e_b, c_b = self._eigh(fock[1], s) + return cupy.array((e_a,e_b)), cupy.array((c_a,c_b)) + + def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1): + if mol is None: mol = self.mol + if dm is None: dm = self.make_rdm1() + + if isinstance(dm, cupy.ndarray) and dm.ndim == 2: + dm = cupy.asarray((dm*.5,dm*.5)) + + if self._eri is not None or not self.direct_scf: + vj, vk = self.get_jk(mol, cupy.asarray(dm), hermi) + vhf = vj[0] + vj[1] - vk + else: + ddm = cupy.asarray(dm) - cupy.asarray(dm_last) + vj, vk = self.get_jk(mol, ddm, hermi) + vhf = vj[0] + vj[1] - vk + vhf += cupy.asarray(vhf_last) + return vhf + + def scf(self, dm0=None, **kwargs): + cput0 = logger.init_timer(self) + + self.dump_flags() + self.build(self.mol) + + if self.max_cycle > 0 or self.mo_coeff is None: + self.converged, self.e_tot, \ + self.mo_energy, self.mo_coeff, self.mo_occ = \ + _kernel(self, self.conv_tol, self.conv_tol_grad, + dm0=dm0, callback=self.callback, + conv_check=self.conv_check, **kwargs) + else: + self.e_tot = _kernel(self, self.conv_tol, self.conv_tol_grad, + dm0=dm0, callback=self.callback, + conv_check=self.conv_check, **kwargs)[1] + + logger.timer(self, 'SCF', *cput0) + self._finalize() + return self.e_tot + kernel = pyscf_lib.alias(scf, alias_name='kernel') + + def spin_square(self, mo_coeff=None, s=None): + if mo_coeff is None: + mo_coeff = (self.mo_coeff[0][:,self.mo_occ[0]>0], + self.mo_coeff[1][:,self.mo_occ[1]>0]) + if s is None: + s = self.get_ovlp() + return spin_square(mo_coeff, s) +