Skip to content

Commit

Permalink
Merge pull request #546 from lmfit/fix_modelresult_loads
Browse files Browse the repository at this point in the history
add MinimizerResult when loading ModelResult
  • Loading branch information
newville authored Mar 26, 2019
2 parents 24eecd1 + e1f7197 commit d1ed894
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 7 deletions.
9 changes: 8 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ install:
- if [[ $version == latest && $TRAVIS_PYTHON_VERSION == 3.7-dev ]]; then conda create -q -n test_env python=3.7 numpy scipy six pandas matplotlib dill sphinx; fi
- source activate test_env
- pip install pytest
- pip install codecov
- pip install asteval
- pip install uncertainties
- if [[ $version == latest ]]; then pip install emcee numdifftools; fi
Expand All @@ -41,4 +42,10 @@ install:
script:
- cd tests
- pytest
- if [[ $version == latest && $TRAVIS_PYTHON_VERSION == 3.7-dev ]]; then cd ../doc ; make ; fi # test building the documentation
- coverage run --source=lmfit -m pytest
- coverage report -m
# test building the documentation
- if [[ $version == latest && $TRAVIS_PYTHON_VERSION == 3.7-dev ]]; then cd ../doc ; make ; fi

after_success:
- codecov
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ LMfit-py
.. image:: https://travis-ci.org/lmfit/lmfit-py.svg
:target: https://travis-ci.org/lmfit/lmfit-py

.. image:: https://codecov.io/gh/lmfit/lmfit-py/branch/master/graph/badge.svg
:target: https://codecov.io/gh/lmfit/lmfit-py

.. image:: https://img.shields.io/pypi/v/lmfit.svg
:target: https://pypi.org/project/lmfit

Expand Down
12 changes: 8 additions & 4 deletions lmfit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from . import Minimizer, Parameter, Parameters, lineshapes
from .confidence import conf_interval
from .jsonutils import HAS_DILL, decode4js, encode4js
from .minimizer import validate_nan_policy
from .minimizer import validate_nan_policy, MinimizerResult
from .printfuncs import ci_report, fit_report

# Use pandas.isnull for aligning missing data if pandas is available.
Expand Down Expand Up @@ -1259,9 +1259,6 @@ def load_modelresult(fname, funcdefs=None):
modres = ModelResult(Model(lambda x: x, None), params)
with open(fname) as fh:
mresult = modres.load(fh, funcdefs=funcdefs)
mresult.data = mresult.userargs[0]
mresult.weights = mresult.userargs[1]
mresult.init_params = mresult.model.make_params(**mresult.init_values)
return mresult


Expand Down Expand Up @@ -1682,6 +1679,13 @@ def loads(self, s, funcdefs=None, **kws):
setattr(self, attr, decode4js(modres.get(attr, None)))

self.best_fit = self.model.eval(self.params, **self.userkws)
if len(self.userargs) == 2:
self.data = self.userargs[0]
self.weights = self.userargs[1]
self.init_params = self.model.make_params(**self.init_values)
self.result = MinimizerResult()
self.result.params = self.params
self.init_vals = list(self.init_values.items())
return self

def load(self, fp, funcdefs=None, **kws):
Expand Down
2 changes: 1 addition & 1 deletion lmfit/printfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def gformat(val, length=11):
if expon > 0:
prec -= expon
fmt = '{0: %i.%i%s}' % (length, prec, form)
return fmt.format(val)
return fmt.format(val)[:length]


CORREL_HEAD = '[[Correlations]] (unreported correlations are < %.3f)'
Expand Down
29 changes: 28 additions & 1 deletion tests/test_saveload.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import pytest

import lmfit.jsonutils
from lmfit import Parameters
from lmfit.model import (load_model, load_modelresult, save_model,
save_modelresult)
save_modelresult, Model, ModelResult)
from lmfit.models import ExponentialModel, GaussianModel
from lmfit_testutils import assert_between, assert_param_between

Expand Down Expand Up @@ -181,3 +182,29 @@ def test_saveload_modelresult_exception():
with pytest.raises(AttributeError, match=r'needs saved ModelResult'):
load_modelresult(SAVE_MODEL)
clear_savefile(SAVE_MODEL)


def test_saveload_modelresult_roundtrip():
"""Test for modelresult.loads()/dumps() and repeating that"""
def mfunc(x, a, b):
return a * (x-b)

model = Model(mfunc)
params = model.make_params(a=0.0, b=3.0)

xx = np.linspace(-5, 5, 201)
yy = 0.5 * (xx - 0.22) + np.random.normal(scale=0.01, size=len(xx))

result1 = model.fit(yy, params, x=xx)

result2 = ModelResult(model, Parameters())
result2.loads(result1.dumps(), funcdefs={'mfunc': mfunc})

result3 = ModelResult(model, Parameters())
result3.loads(result2.dumps(), funcdefs={'mfunc': mfunc})

assert result3 is not None
assert_param_between(result2.params['a'], 0.48, 0.52)
assert_param_between(result2.params['b'], 0.20, 0.25)
assert_param_between(result3.params['a'], 0.48, 0.52)
assert_param_between(result3.params['b'], 0.20, 0.25)

0 comments on commit d1ed894

Please sign in to comment.