Skip to content

Commit

Permalink
smd grad hess (pyscf#86)
Browse files Browse the repository at this point in the history
* cmake workflow for dftd3 and dftd4

* updated cmake

* new workflow for dftd3 and dftd4

* remove package_data

* resolve dependencies in dftd3 and dftd4

* add dftd3 and dftd4 to __init__.py

* updated the script for building wheels

* memory leak in dftd3 and dft4

* fixed memory leak in dftd3, debug for A800

* remove comments

* smd gradient

* fixed a bug in smd gradient

* tag_array for tuple of arrays

* fixed a bug in smd gradient

* increased # of quad points for SASA

* updated readme

* remove unnecessary header

* correct the order in set class

* add sasa_ng to unit test
  • Loading branch information
wxj6000 authored Jan 27, 2024
1 parent d331f87 commit 262bd21
Show file tree
Hide file tree
Showing 23 changed files with 1,416 additions and 506 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ Features
- Dispersion corrections via [DFTD3](https://github.com/dftd3/simple-dftd3) and [DFTD4](https://github.com/dftd4/dftd4);
- Nonlocal functional correction (vv10) for SCF and gradient;
- ECP is supported and calculated on CPU;
- PCM solvent models, analytical gradients, and semi-analytical Hessian matrix;
- SMD solvent models and solvation free energy
- PCM models, SMD model, their analytical gradients, and semi-analytical Hessian matrix;

Limitations
--------
Expand Down
108 changes: 108 additions & 0 deletions benchmarks/smd/benchmark_smd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2024 The GPU4PySCF Authors. All Rights Reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
import numpy
from pyscf import gto
from gpu4pyscf.solvent import smd
from gpu4pyscf.solvent.grad import smd as smd_grad

'''
Benchmark CDS, gradient, hessian in the SMD model
'''

path = '../molecules/organic/'

# calculated with qchem 6.1, in kcal/mol
e_cds_qchem = {}
e_cds_qchem['water'] = {
'020_Vitamin_C.xyz': 5.0737,
'031_Inosine.xyz': 2.7129,
'033_Bisphenol_A.xyz': 6.2620,
'037_Mg_Porphin.xyz': 6.0393,
'042_Penicillin_V.xyz': 6.4349,
'045_Ochratoxin_A.xyz': 8.8526,
'052_Cetirizine.xyz': 4.6430,
'057_Tamoxifen.xyz': 5.4743,
'066_Raffinose.xyz': 10.2543,
'084_Sphingomyelin.xyz': 15.0308,
'095_Azadirachtin.xyz': 16.9321,
'113_Taxol.xyz': 17.2585,
'168_Valinomycin.xyz': 27.3499,
}

e_cds_qchem['ethanol'] = {
'020_Vitamin_C.xyz': 4.2119,
'031_Inosine.xyz': 1.0175,
'033_Bisphenol_A.xyz': -0.2454,
'037_Mg_Porphin.xyz': -2.2391,
'042_Penicillin_V.xyz': 1.8338,
'045_Ochratoxin_A.xyz': 1.0592,
'052_Cetirizine.xyz': -2.5099,
'057_Tamoxifen.xyz': -3.9320,
'066_Raffinose.xyz': 3.1120,
'084_Sphingomyelin.xyz': -3.1963,
'095_Azadirachtin.xyz': 6.5286,
'113_Taxol.xyz': 2.7271,
'168_Valinomycin.xyz': 4.0013,
}

def _check_energy_grad(filename, solvent='water'):
xyz = os.path.join(path, filename)
mol = gto.Mole(atom=xyz)
mol.build()
natm = mol.natm
fd_cds = numpy.zeros([natm,3])
eps = 1e-4
for ia in range(mol.natm):
for j in range(3):
coords = mol.atom_coords(unit='B')
coords[ia,j] += eps
mol.set_geom_(coords, unit='B')
mol.build()

smdobj = smd.SMD(mol)
smdobj.solvent = solvent
e0_cds = smdobj.get_cds()

coords[ia,j] -= 2.0*eps
mol.set_geom_(coords, unit='B')
mol.build()

smdobj = smd.SMD(mol)
smdobj.solvent = solvent
e1_cds = smdobj.get_cds()

coords[ia,j] += eps
mol.set_geom_(coords, unit='B')
fd_cds[ia,j] = (e0_cds - e1_cds) / (2.0 * eps)

smdobj = smd.SMD(mol)
smdobj.solvent = solvent
e_cds = smd.get_cds(smdobj) * smd.hartree2kcal
grad_cds = smd_grad.get_cds(smdobj)
print(f'e_cds by GPU4PySCF: {e_cds}')
print(f'e_cds by Q-Chem: {e_cds_qchem[solvent][filename]}')
print(f'e_cds(Q-Chem) - e_cds(GPU4PySCF): {e_cds - e_cds_qchem[solvent][filename]}')
print(f'norm (fd gradient - analy gradient: {numpy.linalg.norm(fd_cds - grad_cds.get())}')
assert numpy.linalg.norm(fd_cds - grad_cds.get()) < 1e-8

if __name__ == "__main__":
for filename in os.listdir(path):
print(f'---- benchmarking {filename} ----------')
print('in water')
_check_energy_grad(filename, solvent='water')
print('in ethanol')
_check_energy_grad(filename, solvent='ethanol')
2 changes: 1 addition & 1 deletion examples/14-pcm_solvent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
mf.grids.atom_grid = (99,590)
mf.small_rho_cutoff = 1e-10
mf.with_solvent.lebedev_order = 29 # 302 Lebedev grids
mf.with_solvent.method = 'C-PCM'
mf.with_solvent.method = 'IEF-PCM'
mf.with_solvent.eps = 78.3553
mf.kernel()

Expand Down
13 changes: 10 additions & 3 deletions examples/16-smd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@
H -0.7570000000 -0.0000000000 -0.4696000000
H 0.7570000000 0.0000000000 -0.4696000000
'''
mol = pyscf.M(atom=atom, basis='def2-tzvpp', verbose=4)
atom = 'Vitamin_C.xyz'
mol = pyscf.M(atom=atom, basis='def2-tzvpp', verbose=6)

mf = dft.rks.RKS(mol, xc='HYB_GGA_XC_B3LYP')#.density_fit()
mf = dft.rks.RKS(mol, xc='HYB_GGA_XC_B3LYP').density_fit()
mf = mf.SMD()
mf.verbose = 4
mf.verbose = 6
mf.grids.atom_grid = (99,590)
mf.small_rho_cutoff = 1e-10
mf.with_solvent.lebedev_order = 29 # 302 Lebedev grids
mf.with_solvent.method = 'SMD'
mf.with_solvent.solvent = 'water'
e_tot = mf.kernel()
print('total energy with SMD:', e_tot)

gradobj = mf.nuc_grad_method()
f = gradobj.kernel()

hessobj = mf.Hessian()
h = hessobj.kernel()
5 changes: 5 additions & 0 deletions gpu4pyscf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
# monkey patch libxc reference due to a bug in nvcc
from pyscf.dft import libxc
libxc.__reference__ = 'unable to decode the reference due to https://github.com/NVIDIA/cuda-python/issues/29'

from gpu4pyscf.lib.utils import patch_cpu_kernel
from gpu4pyscf.lib.cupy_helper import tag_array
from pyscf import lib
lib.tag_array = tag_array
45 changes: 30 additions & 15 deletions gpu4pyscf/lib/cupy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

import os
import sys
import functools
import numpy as np
import cupy
import ctypes
from pyscf import lib
from gpu4pyscf.lib import logger
from gpu4pyscf.gto import mole
from gpu4pyscf.lib.cutensor import contract
Expand Down Expand Up @@ -93,11 +95,22 @@ def device2host_2d(a_cpu, a_gpu, stream=None):
class CPArrayWithTag(cupy.ndarray):
pass

@functools.wraps(lib.tag_array)
def tag_array(a, **kwargs):
''' attach attributes to cupy ndarray'''
t = cupy.asarray(a).view(CPArrayWithTag)
if isinstance(a, CPArrayWithTag):
t.__dict__.update(a.__dict__)
'''
a should be cupy/numpy array or tuple of cupy/numpy array
attach attributes to cupy ndarray for cupy array
attach attributes to numpy ndarray for numpy array
'''
if isinstance(a, cupy.ndarray) or isinstance(a[0], cupy.ndarray):
t = cupy.asarray(a).view(CPArrayWithTag)
if isinstance(a, CPArrayWithTag):
t.__dict__.update(a.__dict__)
else:
t = np.asarray(a).view(lib.NPArrayWithTag)
if isinstance(a, lib.NPArrayWithTag):
t.__dict__.update(a.__dict__)
t.__dict__.update(kwargs)
return t

Expand Down Expand Up @@ -168,19 +181,23 @@ def add_sparse(a, b, indices):
raise RuntimeError('failed in sparse_add2d')
return a

def dist_matrix(coords, out=None):
assert coords.flags.c_contiguous
n = coords.shape[0]
def dist_matrix(x, y, out=None):
assert x.flags.c_contiguous
assert y.flags.c_contiguous

m = x.shape[0]
n = y.shape[0]
if out is None:
out = cupy.empty([n,n])
out = cupy.empty([m,n])

stream = cupy.cuda.get_current_stream()
err = libcupy_helper.dist_matrix(
ctypes.cast(stream.ptr, ctypes.c_void_p),
ctypes.cast(out.data.ptr, ctypes.c_void_p),
ctypes.cast(coords.data.ptr, ctypes.c_void_p),
ctypes.cast(coords.data.ptr, ctypes.c_void_p),
ctypes.c_int(n),
ctypes.cast(x.data.ptr, ctypes.c_void_p),
ctypes.cast(y.data.ptr, ctypes.c_void_p),
ctypes.c_int(m),
ctypes.c_int(n)
)
if err != 0:
raise RuntimeError('failed in calculating distance matrix')
Expand Down Expand Up @@ -376,10 +393,8 @@ def cart2sph(t, axis=0, ang=1, out=None):
t_cart = t.reshape([i0*nli, li_size[0], i3])
if(out is not None):
out = out.reshape([i0*nli, li_size[1], i3])
out[:] = cupy.einsum('min,ip->mpn', t_cart, c2s)
else:
out = cupy.einsum('min,ip->mpn', t_cart, c2s)
return out.reshape(out_shape)
t_sph = contract('min,ip->mpn', t_cart, c2s, out=out)
return t_sph.reshape(out_shape)

# a copy with modification from
# https://github.com/pyscf/pyscf/blob/9219058ac0a1bcdd8058166cad0fb9127b82e9bf/pyscf/lib/linalg_helper.py#L1536
Expand Down
13 changes: 7 additions & 6 deletions gpu4pyscf/lib/cupy_helper/dist_matrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
#define THREADS 32

__global__
static void _calc_distances(double *dist, const double *x, const double *y, int n)
static void _calc_distances(double *dist, const double *x, const double *y, int m, int n)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
int j = blockIdx.y * blockDim.y + threadIdx.y;
if (i >= n || j >= n){
if (i >= m || j >= n){
return;
}

Expand All @@ -34,12 +34,13 @@ static void _calc_distances(double *dist, const double *x, const double *y, int
}

extern "C" {
int dist_matrix(cudaStream_t stream, double *dist, const double *x, const double *y, int n)
int dist_matrix(cudaStream_t stream, double *dist, const double *x, const double *y, int m, int n)
{
int ntile = (n + THREADS - 1) / THREADS;
int ntilex = (m + THREADS - 1) / THREADS;
int ntiley = (n + THREADS - 1) / THREADS;
dim3 threads(THREADS, THREADS);
dim3 blocks(ntile, ntile);
_calc_distances<<<blocks, threads, 0, stream>>>(dist, x, y, n);
dim3 blocks(ntilex, ntiley);
_calc_distances<<<blocks, threads, 0, stream>>>(dist, x, y, m, n);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
return 1;
Expand Down
Loading

0 comments on commit 262bd21

Please sign in to comment.