Skip to content

Commit

Permalink
Add pickle serialization (pyscf#294)
Browse files Browse the repository at this point in the history
* Add pickle serialization (fix pyscf#267)

* syntax error

* Fix DFHF serialization tests

---------

Co-authored-by: Qiming Sun <[email protected]>
  • Loading branch information
sunqm and Qiming Sun authored Dec 28, 2024
1 parent ad52eba commit 177fb05
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 7 deletions.
6 changes: 5 additions & 1 deletion gpu4pyscf/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class DF(lib.StreamObject):
from gpu4pyscf.lib.utils import to_gpu, device

_keys = {'intopt', 'mol', 'auxmol', 'use_gpu_memory'}
_keys = {'intopt', 'nao', 'naux', 'cd_low', 'mol', 'auxmol', 'use_gpu_memory'}

def __init__(self, mol, auxbasis=None):
self.mol = mol
Expand All @@ -52,8 +52,12 @@ def __init__(self, mol, auxbasis=None):
self.naux = None
self.cd_low = None
self._cderi = None
self._vjopt = None
self._rsh_df = {}

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('cd_low', 'intopt', '_cderi', '_vjopt'))

@property
def auxbasis(self):
return self._auxbasis
Expand Down
3 changes: 1 addition & 2 deletions gpu4pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class _DFHF:
to_gpu = utils.to_gpu
device = utils.device
__name_mixin__ = 'DF'
_keys = {'rhoj', 'rhok', 'disp', 'screen_tol'}
_keys = {'rhoj', 'rhok', 'disp', 'screen_tol', 'with_df', 'only_dfj'}

def __init__(self, mf, dfobj, only_dfj):
self.__dict__.update(mf.__dict__)
Expand All @@ -132,7 +132,6 @@ def __init__(self, mf, dfobj, only_dfj):
self.direct_scf = False
self.with_df = dfobj
self.only_dfj = only_dfj
self._keys = mf._keys.union(['with_df', 'only_dfj'])

def undo_df(self):
'''Remove the DFHF Mixin'''
Expand Down
12 changes: 11 additions & 1 deletion gpu4pyscf/df/tests/test_df_rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@
# limitations under the License.

import unittest
import pickle
import numpy as np
import pyscf
from pyscf import scf as cpu_scf
from pyscf.df import df_jk as cpu_df_jk
from gpu4pyscf.df import df_jk as gpu_df_jk
from gpu4pyscf import scf as gpu_scf
try:
import cloudpickle
except ImportError:
cloudpickle = None

atom = '''
O 0.0000000000 -0.0000000000 0.1174000000
Expand Down Expand Up @@ -48,12 +53,17 @@ class KnownValues(unittest.TestCase):
'''
def test_rhf(self):
print('------- RHF -----------------')
mf = gpu_scf.RHF(mol_sph).density_fit(auxbasis='def2-tzvpp-jkfit')
mf = mol_sph.RHF().density_fit(auxbasis='def2-tzvpp-jkfit').to_gpu()
e_tot = mf.kernel()
e_qchem = -76.0624582299
print(f'diff from qchem {e_tot - e_qchem}')
assert np.abs(e_tot - e_qchem) < 1e-5

# test serialization
if cloudpickle is not None:
mf1 = pickle.loads(cloudpickle.dumps(mf))
assert mf1.e_tot == e_tot

def test_cart(self):
print('------- RHF Cart -----------------')
mf = gpu_scf.RHF(mol_cart).density_fit(auxbasis='def2-tzvpp-jkfit')
Expand Down
3 changes: 3 additions & 0 deletions gpu4pyscf/dft/numint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,9 @@ class NumInt(lib.StreamObject, LibXCMixin):
screen_index = None
xcfuns = None # can be multiple xc functionals

__getstate__, __setstate__ = lib.generate_pickle_methods(
excludes=('gdftopt',))

def build(self, mol, coords):
self.gdftopt = _GDFTOpt.from_mol(mol)
self.grid_blksize = None
Expand Down
10 changes: 8 additions & 2 deletions gpu4pyscf/dft/rks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.

# modified by Xiaojie Wu ([email protected])

import cupy
from pyscf.dft import rks

from gpu4pyscf.lib import logger
from gpu4pyscf.dft import numint, gen_grid
from gpu4pyscf.scf import hf
Expand Down Expand Up @@ -257,6 +257,7 @@ def __init__(self, xc='LDA,VWN'):
##################################################
# don't modify the following attributes, they are not input options
self._numint = numint.NumInt()

@property
def omega(self):
return self._numint.omega
Expand Down Expand Up @@ -291,8 +292,13 @@ def reset(self, mol=None):
hf.SCF.reset(self, mol)
self.grids.reset(mol)
self.nlcgrids.reset(mol)
self.cphf_grids.reset(mol)
self._numint.reset()
# The cphf_grids attribute is not available in the PySCF CPU version.
# In PySCF's to_gpu() function, this attribute is not properly
# initialized. mol of the KS object must be used for initialization.
if mol is None:
mol = self.mol
self.cphf_grids.reset(mol)
return self

def nuc_grad_method(self):
Expand Down
10 changes: 9 additions & 1 deletion gpu4pyscf/dft/tests/test_rks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import numpy as np
import unittest
import pyscf
Expand Down Expand Up @@ -64,11 +65,18 @@ class KnownValues(unittest.TestCase):
'''
def test_rks_lda(self):
print('------- LDA ----------------')
e_tot = run_dft("LDA, vwn5", mol_sph)
mf = mol_sph.RKS(xc='LDA,vwn5').to_gpu()
mf.grids.level = grids_level
mf.nlcgrids.level = nlcgrids_level
e_tot = mf.kernel()
e_ref = -75.9046410402
print('| CPU - GPU |:', e_tot - e_ref)
assert np.abs(e_tot - e_ref) < 1e-5

# test serialization
mf1 = pickle.loads(pickle.dumps(mf))
assert mf1.e_tot == e_tot

def test_rks_pbe(self):
print('------- PBE ----------------')
e_tot = run_dft('PBE', mol_sph)
Expand Down
3 changes: 3 additions & 0 deletions gpu4pyscf/scf/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ def __init__(self, mol):
self._opt_gpu = {None: None}
self._eri = None # Note: self._eri requires large amount of memory

__getstate__, __setstate__ = pyscf_lib.generate_pickle_methods(
excludes=('_opt_gpu', '_eri', '_numint'))

def check_sanity(self):
s1e = self.get_ovlp()
if isinstance(s1e, cupy.ndarray) and s1e.ndim == 2:
Expand Down

0 comments on commit 177fb05

Please sign in to comment.