Skip to content

Commit

Permalink
Merge pull request #586 from reneeotten/nan_policy_message
Browse files Browse the repository at this point in the history
Nan policy message
  • Loading branch information
newville authored Aug 29, 2019
2 parents 6478b7d + d76a546 commit d7b408b
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: 'versioneer.py|lmfit/_version|doc/conf.py'

repos:
- repo: https://github.com/asottile/pyupgrade
rev: v1.22.1
rev: v1.23.0
hooks:
- id: pyupgrade
# for now don't force to change from %-operator to {}
Expand Down
6 changes: 4 additions & 2 deletions doc/whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ Various:
(replaced by omit) deprecated since 0.9 (PR #565).
- deprecate 'report_errors' in printfuncs.py (PR #571)
- updates to the documentation to use ``jupyter-sphinx`` to include examples/output (PRs #573 and #575)
- include a Gallery with examples in the documentation using ``sphinx-gallery`` (PR #574)
- improve test-coverage (PRs #571 and #572)
- include a Gallery with examples in the documentation using ``sphinx-gallery`` (PR #574 and #583)
- improve test-coverage (PRs #571, #572 and #585)
- add/clarify warning messages when NaN values are detected (PR #586)
- several updates to docstrings (Issue #584; PR #583, and others)
- update pre-commit hooks and several docstrings

.. _whatsnew_0913_label:
Expand Down
6 changes: 5 additions & 1 deletion lmfit/minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,7 +2305,11 @@ def _nan_policy(arr, nan_policy='raise', handle_inf=True):
"NaNs will be ignored.", RuntimeWarning)

if contains_nan:
raise ValueError("The input contains nan values")
msg = ('NaN values detected in your input data or the output of '
'your objective/model function - fitting algorithms cannot '
'handle this! Please read https://lmfit.github.io/lmfit-py/faq.html#i-get-errors-from-nan-in-my-fit-what-can-i-do '
'for more information.')
raise ValueError(msg)
return arr


Expand Down
10 changes: 9 additions & 1 deletion lmfit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,15 @@ def _residual(self, params, data, weights, **kwargs):
The "ravels" throughout are necessary to support pandas.Series.
"""
diff = self.eval(params, **kwargs) - data
model = self.eval(params, **kwargs)
if self.nan_policy == 'raise' and not np.all(np.isfinite(model)):
msg = ('The model function generated NaN values and the fit '
'aborted! Please check your model function and/or set '
'boundaries on parameters where applicable. In cases like '
'this, using "nan_policy=\'omit\'" will probably not work.')
raise ValueError(msg)

diff = model - data

if diff.dtype == np.complex:
# data/model are complex
Expand Down
25 changes: 24 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,13 +622,15 @@ def gauss(x, sigma, mu, A):
self.assertTrue(result.params['pos_delta'].value > 0)

def test_model_nan_policy(self):
"""Tests for nan_policy with NaN values in the input data."""
x = np.linspace(0, 10, 201)
np.random.seed(0)
y = gaussian(x, 10.0, 6.15, 0.8)
y += gaussian(x, 8.0, 6.35, 1.1)
y += gaussian(x, 0.25, 6.00, 7.5)
y += np.random.normal(size=len(x), scale=0.5)

# with NaN values in the input data
y[55] = y[91] = np.nan
mod = PseudoVoigtModel()
params = mod.make_params(amplitude=20, center=5.5,
Expand All @@ -637,7 +639,9 @@ def test_model_nan_policy(self):

# with raise, should get a ValueError
result = lambda: mod.fit(y, params, x=x, nan_policy='raise')
self.assertRaises(ValueError, result)
msg = ('NaN values detected in your input data or the output of your '
'objective/model function - fitting algorithms cannot handle this!')
self.assertRaisesRegexp(ValueError, msg, result)

# with propagate, should get no error, but bad results
result = mod.fit(y, params, x=x, nan_policy='propagate')
Expand All @@ -663,6 +667,25 @@ def test_model_nan_policy(self):
with pytest.raises(ValueError, match=err_msg):
mod.fit(y, params, x=x, nan_policy='wrong_argument')

def test_model_nan_policy_NaNs_by_model(self):
"""Test for nan_policy with NaN values generated by the model function."""
def double_exp(x, a1, t1, a2, t2):
return a1*np.exp(-x/t1) + a2*np.exp(-(x-0.1) / t2)

model = Model(double_exp)

truths = (3.0, 2.0, -5.0, 10.0)
x = np.linspace(1, 10, 250)
np.random.seed(0)
y = double_exp(x, *truths) + 0.1*np.random.randn(x.size)

p = model.make_params(a1=4, t1=3, a2=4, t2=3)
result = lambda: model.fit(data=y, params=p, x=x, method='Nelder',
nan_policy='raise')

msg = 'The model function generated NaN values and the fit aborted!'
self.assertRaisesRegexp(ValueError, msg, result)

@pytest.mark.skipif(sys.version_info.major == 2,
reason="cannot use wrapped functions with Python 2")
def test_wrapped_model_func(self):
Expand Down

0 comments on commit d7b408b

Please sign in to comment.