Skip to content

Commit

Permalink
Remove dependency of params_shape on template_params (Issue #62)
Browse files Browse the repository at this point in the history
What:
1. Turn params_shape into an abstract property
2. Use params_shape instead of template_params to compute nparams in
   BaseWavefunction
3. Use nparams in BaseWavefunction.load_cache instead of params.size
4. Move around tests
Why:
1. Every time param_shape is called, template_params needs to be
   constructed. For most wavefunctions, this is not a big problem, but
   some wavefunctions actually need to do a fair bit to obtain a
   template (e.g. APr2G). The consistency between params_shape and
   template_params will be checked implicitly through assign_params,
   which checks that the given parameters has the same shape as the
   params_shape.

Fix up BaseWavefunction.nparams
  • Loading branch information
kimt33 committed Jul 1, 2018
1 parent 6f6d22c commit 9a80738
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
30 changes: 15 additions & 15 deletions wfns/wfn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,7 @@ def nparams(self):
Number of parameters.
"""
return self.template_params.size

@property
def params_shape(self):
"""Return the shape of the wavefunction parameters.
Returns
-------
params_shape : tuple of int
Shape of the parameters.
"""
return self.template_params.shape
return np.prod(self.params_shape)

@property
def spin(self):
Expand Down Expand Up @@ -357,7 +345,7 @@ def load_cache(self):
if self.memory == np.inf:
memory = None
else:
memory = int((self.memory - 5*8*self.params.size) / (self.params.size + 1))
memory = int((self.memory - 5 * 8 * self.nparams) / (self.nparams + 1))

# create function that will be cached
@functools.lru_cache(maxsize=memory, typed=False)
Expand Down Expand Up @@ -453,6 +441,18 @@ def clear_cache(self, key=None):
raise AttributeError('Given cached function does not have decorator '
'`functools.lru_cache`') from error

@abc.abstractproperty
def params_shape(self):
"""Return the shape of the wavefunction parameters.
Returns
-------
params_shape : tuple of int
Shape of the parameters.
"""
pass

@abc.abstractproperty
def template_params(self):
"""Return the template of the parameters of the given wavefunction.
Expand All @@ -464,7 +464,7 @@ def template_params(self):
Notes
-----
May depend on other attributes or properties.
May depend on params_shape and other attributes/properties.
"""
pass
Expand Down
27 changes: 18 additions & 9 deletions wfns/wfn/test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def __init__(self):
def get_overlap(self):
pass

@property
def params_shape(self):
return (10, 10)

@property
def template_params(self):
return np.identity(10)
Expand Down Expand Up @@ -124,9 +128,20 @@ def test_assign_params():
assert not np.allclose(np.real(test.params), np.identity(10))
assert not np.allclose(np.imag(test.params), np.zeros((10, 10)))

# FIXME: hard to test because property of a class/instance cannot be overwritten easily
# test.template_params = np.array([[0.0]])
# test.assign_params(2)
# for testing one line of code
class TempTestWavefunction(TestWavefunction):
@property
def params_shape(self):
return (1, 1, 1)

@property
def template_params(self):
return np.array([[[0.0]]])

test = TempTestWavefunction()
test.assign_dtype(complex)
test.assign_params(2.0)
assert test.params.shape == (1, 1, 1)


def test_init():
Expand Down Expand Up @@ -215,12 +230,6 @@ def test_nparams():
assert test.nparams == 100


def test_params_shape():
"""Test BaseWavefunction.params_shape."""
test = TestWavefunction()
assert test.params_shape == (10, 10)


def test_spin():
"""Test BaseWavefunction.spin"""
test = TestWavefunction()
Expand Down

0 comments on commit 9a80738

Please sign in to comment.