Skip to content

Commit

Permalink
fixed a bug in SMD gradient and Hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
wxj6000 committed Mar 7, 2024
1 parent 73493fa commit f2edd12
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 5 additions & 4 deletions gpu4pyscf/solvent/grad/smd.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,15 @@ def kernel(self, *args, dm=None, atmlst=None, **kwargs):
dm = kwargs.pop('dm', None)
if dm is None:
dm = self.base.make_rdm1(ao_repr=True)

self.de_solute = super().kernel(*args, **kwargs)
if dm.ndim == 3:
dm = dm[0] + dm[1]
self.de_solute = super().kernel(*args, **kwargs)
self.de_solvent = pcm_grad.grad_qv(self.base.with_solvent, dm)
self.de_solvent+= pcm_grad.grad_solver(self.base.with_solvent, dm)
self.de_solvent+= pcm_grad.grad_nuc(self.base.with_solvent, dm)
self.de_cds = get_cds(self.base.with_solvent)
self.de = self.de_solute + self.de_solvent + self.de_cds

self.de = self.de_solute + self.de_solvent
self.de += get_cds(self.base.with_solvent)
if self.verbose >= logger.NOTE:
logger.note(self, '--------------- %s (+%s) gradients ---------------',
self.base.__class__.__name__,
Expand Down
6 changes: 4 additions & 2 deletions gpu4pyscf/solvent/hessian/smd.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,15 @@ def kernel(self, *args, dm=None, atmlst=None, **kwargs):
dm = kwargs.pop('dm', None)
if dm is None:
dm = self.base.make_rdm1(ao_repr=True)
if dm.ndim == 3:
dm = dm[0] + dm[1]
is_equilibrium = self.base.with_solvent.equilibrium_solvation
self.base.with_solvent.equilibrium_solvation = True
self.de_solvent = pcm_hess.hess_elec(self.base.with_solvent, dm, verbose=self.verbose)
#self.de_solvent+= hess_nuc(self.base.with_solvent)
self.de_solute = super().kernel(*args, **kwargs)
self.de = self.de_solute + self.de_solvent
self.de += get_cds(self.base.with_solvent)
self.de_cds = get_cds(self.base.with_solvent)
self.de = self.de_solute + self.de_solvent + self.de_cds
self.base.with_solvent.equilibrium_solvation = is_equilibrium
return self.de

Expand Down

0 comments on commit f2edd12

Please sign in to comment.