Skip to content

Commit

Permalink
Merge pull request #520 from reneeotten/pytest
Browse files Browse the repository at this point in the history
test-suite: use pytest features, improve coverage, fix mistakes
  • Loading branch information
newville authored Nov 29, 2018
2 parents 7fec373 + 852685a commit 1268177
Show file tree
Hide file tree
Showing 8 changed files with 487 additions and 406 deletions.
13 changes: 7 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ before_install:
- conda info -a

install:
- if [[ $version == minimum && $TRAVIS_PYTHON_VERSION != 3.6 && $TRAVIS_PYTHON_VERSION != 3.7-dev ]]; then conda create -q -n test_env python=$TRAVIS_PYTHON_VERSION numpy=1.10 scipy=0.17 six=1.10 pytest; fi
- if [[ $version == minimum && $TRAVIS_PYTHON_VERSION == 3.6 ]]; then conda create -q -n test_env python=$TRAVIS_PYTHON_VERSION numpy=1.11.2 scipy=0.18 six=1.10 pytest; fi
- if [[ $version == minimum && $TRAVIS_PYTHON_VERSION == 3.7-dev ]]; then conda create -q -n test_env python=3.7 numpy=1.11.3 scipy=1.1 six=1.11 pytest; fi
- if [[ $version == latest && $TRAVIS_PYTHON_VERSION != 3.7-dev ]]; then conda create -q -n test_env python=$TRAVIS_PYTHON_VERSION numpy scipy six pandas matplotlib dill pytest; fi
- if [[ $version == latest && $TRAVIS_PYTHON_VERSION == 3.7-dev ]]; then conda create -q -n test_env python=3.7 numpy scipy six pandas matplotlib dill pytest; fi
- if [[ $version == minimum && $TRAVIS_PYTHON_VERSION != 3.6 && $TRAVIS_PYTHON_VERSION != 3.7-dev ]]; then conda create -q -n test_env python=$TRAVIS_PYTHON_VERSION numpy=1.10 scipy=0.17 six=1.10; fi
- if [[ $version == minimum && $TRAVIS_PYTHON_VERSION == 3.6 ]]; then conda create -q -n test_env python=$TRAVIS_PYTHON_VERSION numpy=1.11.2 scipy=0.18 six=1.10; fi
- if [[ $version == minimum && $TRAVIS_PYTHON_VERSION == 3.7-dev ]]; then conda create -q -n test_env python=3.7 numpy=1.11.3 scipy=1.1 six=1.11; fi
- if [[ $version == latest && $TRAVIS_PYTHON_VERSION != 3.7-dev ]]; then conda create -q -n test_env python=$TRAVIS_PYTHON_VERSION numpy scipy six pandas matplotlib dill; fi
- if [[ $version == latest && $TRAVIS_PYTHON_VERSION == 3.7-dev ]]; then conda create -q -n test_env python=3.7 numpy scipy six pandas matplotlib dill; fi
- source activate test_env
- pip install pytest
- pip install asteval
- pip install emcee
- pip install uncertainties
- if [[ $version == latest ]]; then pip install emcee numdifftools; fi
- python setup.py install
- conda list

Expand Down
55 changes: 24 additions & 31 deletions lmfit/minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def flatchain(self):
def show_candidates(self, candidate_nmb='all'):
"""Show pretty_print() representation of candidates from `brute` method.
Showing candidates (default is 'all') or the specified candidate-#
Showing all stored candidates (default) or the specified candidate-#
from the `brute` method.
Parameters
Expand All @@ -319,16 +319,19 @@ def show_candidates(self, candidate_nmb='all'):
"""
if hasattr(self, 'candidates'):
try:
candidate = self.candidates[candidate_nmb]
print("\nCandidate #{}, chisqr = "
"{:.3f}".format(candidate_nmb, candidate.score))
candidate.params.pretty_print()
except IndexError:
if candidate_nmb == 'all':
for i, candidate in enumerate(self.candidates):
print("\nCandidate #{}, chisqr = "
"{:.3f}".format(i, candidate.score))
"{:.3f}".format(i+1, candidate.score))
candidate.params.pretty_print()
elif (candidate_nmb < 1 or candidate_nmb > len(self.candidates)):
raise ValueError("'candidate_nmb' should be between 1 and {}."
.format(len(self.candidates)))
else:
candidate = self.candidates[candidate_nmb-1]
print("\nCandidate #{}, chisqr = "
"{:.3f}".format(candidate_nmb, candidate.score))
candidate.params.pretty_print()

def _calculate_statistics(self):
"""Calculate the fitting statistics."""
Expand Down Expand Up @@ -918,13 +921,11 @@ def scalar_minimize(self, method='Nelder-Mead', params=None, **kws):
result.x = np.atleast_1d(result.x)
result.residual = self.__residual(result.x)
result.nfev -= 1
else:
pass

result._calculate_statistics()

# calculate the cov_x and estimate uncertanties/correlations
if (self.calc_covar and HAS_NUMDIFFTOOLS and
if (not result.aborted and self.calc_covar and HAS_NUMDIFFTOOLS and
len(result.residual) > len(result.var_names)):
_covar_ndt = self._calculate_covariance_matrix(result.x)
if _covar_ndt is not None:
Expand Down Expand Up @@ -1373,27 +1374,25 @@ def least_squares(self, params=None, **kws):
bounds=(lower_bounds, upper_bounds),
kwargs=dict(apply_bounds_transformation=False),
**kws)
result.residual = ret.fun
except AbortFitException:
pass

result._calculate_statistics()

if not result.aborted:
for attr in ret:
setattr(result, attr, ret[attr])

result.x = np.atleast_1d(result.x)
result.residual = ret.fun
else:
pass

result._calculate_statistics()

# calculate the cov_x and estimate uncertainties/correlations
try:
hess = np.matmul(ret.jac.T, ret.jac)
result.covar = np.linalg.inv(hess)
self._calculate_uncertainties_correlations()
except LinAlgError:
pass
# calculate the cov_x and estimate uncertainties/correlations
try:
hess = np.matmul(ret.jac.T, ret.jac)
result.covar = np.linalg.inv(hess)
self._calculate_uncertainties_correlations()
except LinAlgError:
pass

return result

Expand Down Expand Up @@ -1556,13 +1555,11 @@ def basinhopping(self, params=None, **kws):
result.message = ret.message
result.residual = self.__residual(ret.x)
result.nfev -= 1
else:
pass

result._calculate_statistics()

# calculate the cov_x and estimate uncertanties/correlations
if (self.calc_covar and HAS_NUMDIFFTOOLS and
if (not result.aborted and self.calc_covar and HAS_NUMDIFFTOOLS and
len(result.residual) > len(result.var_names)):
_covar_ndt = self._calculate_covariance_matrix(ret.x)
if _covar_ndt is not None:
Expand Down Expand Up @@ -1726,8 +1723,6 @@ def brute(self, params=None, Ns=20, keep=50):
result.params = result.candidates[0].params
result.residual = self.__residual(result.brute_x0, apply_bounds_transformation=False)
result.nfev -= 1
else:
pass

result._calculate_statistics()

Expand Down Expand Up @@ -1833,13 +1828,11 @@ def ampgo(self, params=None, **kws):

result.residual = self.__residual(result.ampgo_x0)
result.nfev -= 1
else:
pass

result._calculate_statistics()

# calculate the cov_x and estimate uncertanties/correlations
if (self.calc_covar and HAS_NUMDIFFTOOLS and
if (not result.aborted and self.calc_covar and HAS_NUMDIFFTOOLS and
len(result.residual) > len(result.var_names)):
_covar_ndt = self._calculate_covariance_matrix(result.ampgo_x0)
if _covar_ndt is not None:
Expand Down
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pytest

import lmfit


@pytest.fixture
def minimizer_Alpine02():
"""Return a lmfit Minimizer object for the Alpine02 function."""
def residual_Alpine02(params):
x0 = params['x0'].value
x1 = params['x1'].value
return np.prod(np.sqrt(x0) * np.sin(x0)) * np.prod(np.sqrt(x1) *
np.sin(x1))

# create Parameters and set initial values and bounds
pars = lmfit.Parameters()
pars.add_many(('x0', 1., True, 0.0, 10.0),
('x1', 1., True, 0.0, 10.0))

mini = lmfit.Minimizer(residual_Alpine02, pars)
return mini
111 changes: 83 additions & 28 deletions tests/test_ampgo.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,104 @@
"""Tests for the AMPGO global minimization algorithm."""
import sys

import numpy as np
from numpy.testing import assert_allclose
import pytest

import lmfit

# correct result for Alpine02 function
global_optimum = [7.91705268, 4.81584232]
fglob = -6.12950


def test_ampgo_Alpine02():
@pytest.mark.parametrize("tabustrategy", ['farthest', 'oldest'])
def test_ampgo_Alpine02(minimizer_Alpine02, tabustrategy):
"""Test AMPGO algorithm on Alpine02 function."""
kws = {'tabustrategy': tabustrategy}
out = minimizer_Alpine02.minimize(method='ampgo', **kws)
out_x = np.array([out.params['x0'].value, out.params['x1'].value])

assert_allclose(out.residual, fglob, rtol=1e-5)
assert_allclose(min(out_x), min(global_optimum), rtol=1e-3)
assert_allclose(max(out_x), max(global_optimum), rtol=1e-3)
assert 'global' in out.ampgo_msg


def test_ampgo_bounds(minimizer_Alpine02):
"""Test AMPGO algorithm with bounds."""
# change boundaries of parameters
pars_bounds = lmfit.Parameters()
pars_bounds.add_many(('x0', 1., True, 5.0, 15.0),
('x1', 1., True, 2.5, 7.5))

global_optimum = [7.91705268, 4.81584232]
fglob = -6.12950
out = minimizer_Alpine02.minimize(params=pars_bounds, method='ampgo')
assert 5.0 <= out.params['x0'].value <= 15.0
assert 2.5 <= out.params['x1'].value <= 7.5

def residual_Alpine02(params):
x0 = params['x0'].value
x1 = params['x1'].value
return np.prod(np.sqrt(x0) * np.sin(x0)) * np.prod(np.sqrt(x1) *
np.sin(x1))

pars = lmfit.Parameters()
pars.add_many(('x0', 1., True, 0.0, 10.0),
('x1', 1., True, 0.0, 10.0))
def test_ampgo_disp_true(minimizer_Alpine02, capsys):
"""Test AMPGO algorithm with disp is True."""
# disp to False for L-BFGS-B to avoid too much output...
kws = {'disp': True, 'local_opts': {'disp': False}}
minimizer_Alpine02.minimize(method='ampgo', **kws)
captured = capsys.readouterr()
assert "Starting MINIMIZATION Phase" in captured.out

mini = lmfit.Minimizer(residual_Alpine02, pars)
out = mini.minimize(method='ampgo')

def test_ampgo_maxfunevals(minimizer_Alpine02):
"""Test AMPGO algorithm with maxfunevals."""
# disp to False for L-BFGS-B to avoid too much output...
kws = {'maxfunevals': 5, 'disp': True, 'local_opts': {'disp': False}}
out = minimizer_Alpine02.minimize(method='ampgo', **kws)

assert out.ampgo_msg == 'Maximum number of function evaluations exceeded'


def test_ampgo_local_solver(minimizer_Alpine02):
"""Test AMPGO algorithm with local solver."""
kws = {'local': 'Nelder-Mead'}
out = minimizer_Alpine02.minimize(method='ampgo', **kws)
out_x = np.array([out.params['x0'].value, out.params['x1'].value])

assert 'ampgo' and 'Nelder-Mead' in out.method
assert_allclose(out.residual, fglob, rtol=1e-5)
assert_allclose(min(out_x), min(global_optimum), rtol=1e-3)
assert_allclose(max(out_x), max(global_optimum), rtol=1e-3)
assert('global' in out.ampgo_msg)
assert 'global' in out.ampgo_msg


def test_ampgo_Alpine02_maxfunevals():
"""Test AMPGO algorithm on Alpine02 function."""
def test_ampgo_invalid_local_solver(minimizer_Alpine02):
"""Test AMPGO algorithm with invalid local solvers."""
kws = {'local': 'leastsq'}
with pytest.raises(Exception, match=r'Invalid local solver selected'):
minimizer_Alpine02.minimize(method='ampgo', **kws)


def test_ampgo_invalid_tabulistsize(minimizer_Alpine02):
"""Test AMPGO algorithm with invalid tabulistsize."""
kws = {'tabulistsize': 0}
with pytest.raises(Exception, match=r'Invalid tabulistsize specified'):
minimizer_Alpine02.minimize(method='ampgo', **kws)


def test_ampgo_invalid_tabustrategy(minimizer_Alpine02):
"""Test AMPGO algorithm with invalid tabustrategy."""
kws = {'tabustrategy': 'unknown'}
with pytest.raises(Exception, match=r'Invalid tabustrategy specified'):
minimizer_Alpine02.minimize(method='ampgo', **kws)

def residual_Alpine02(params):
x0 = params['x0'].value
x1 = params['x1'].value
return np.prod(np.sqrt(x0) * np.sin(x0)) * np.prod(np.sqrt(x1) *
np.sin(x1))

pars = lmfit.Parameters()
pars.add_many(('x0', 1., True, 0.0, 10.0),
('x1', 1., True, 0.0, 10.0))
@pytest.mark.skipif(sys.version_info.major == 2,
reason="does not throw an exception in Python 2")
def test_ampgo_local_opts(minimizer_Alpine02):
"""Test AMPGO algorithm, pass local_opts to solver."""
# use local_opts to pass maxiter to the local optimizer: providing a string
# whereas an integer is required, this should throw an error.
kws = {'local_opts': {'maxiter': 'string'}}
with pytest.raises(TypeError):
minimizer_Alpine02.minimize(method='ampgo', **kws)

mini = lmfit.Minimizer(residual_Alpine02, pars)
kws = {'maxfunevals': 50}
out = mini.minimize(method='ampgo', **kws)
assert('function' in out.ampgo_msg)
# for coverage: make sure that both occurences are reached
kws = {'local_opts': {'maxiter': 10}, 'maxfunevals': 50}
minimizer_Alpine02.minimize(method='ampgo', **kws)
Loading

0 comments on commit 1268177

Please sign in to comment.