diff --git a/paramnormal/dist.py b/paramnormal/dist.py index b165eef..967ada2 100644 --- a/paramnormal/dist.py +++ b/paramnormal/dist.py @@ -797,6 +797,7 @@ class exponential(BaseDist_Mixin): numpy.random.exponential """ + dist = stats.expon param_template = namedtuple('params', ['lamda', 'loc']) name = 'exponential' @@ -833,7 +834,7 @@ class rice(BaseDist_Mixin): R : float The shape parameter of the distribution. sigma : float - The standard deviate of the distribution. + The standard deviation of the distribution. loc : float, optional Location parameter of the distribution. This defaults to, and should probably be left at, 0. @@ -879,6 +880,7 @@ class rice(BaseDist_Mixin): numpy.random.exponential """ + dist = stats.rice param_template = namedtuple('params', ['R', 'sigma', 'loc']) name = 'rice' @@ -904,6 +906,114 @@ def fit(cls, data, **guesses): return cls.param_template(R=b*sigma, loc=loc, sigma=sigma) +class truncated_normal(BaseDist_Mixin): + """ + Create and fit data to a truncated normal distribution. + + Methods + ------- + fit + Use scipy's maximum likelihood estimation methods to estimate + the parameters of the data's distribution. + from_params + Create a new distribution instances from the ``namedtuple`` + result of the :meth:`~fit` method. + + Parameters + ---------- + lower, upper : float + The lower and upper limits of the distribution that serve as its + shape parameters. + mu : float, optional (default = 0) + The expected value (mean) of the underlying normal distribution. + Acts as the location parameter of the distribution. + sigma : float, optional (default = 1) + The standard deviation of the underlying normal distribution. + Also acts as the scale parameter of distribution. + + Examples + -------- + >>> import numpy + >>> import paramnormal as pn + >>> numpy.random.seed(0) + >>> pn.truncated_normal(lower=-0.5, upper=0.5).rvs(size=3) + array([ 0.04687082, 0.20804061, 0.09879796]) + + >>> # you can also use greek letters + >>> numpy.random.seed(0) + >>> pn.truncated_normal(lower=-0.5, upper=2.5, σ=2).rvs(size=3) + array([ 0.8902748 , 1.37377049, 1.04012565]) + + >>> # silly fake data + >>> numpy.random.seed(0) + >>> data = pn.truncated_normal(lower=-0.5, upper=2.5, mu=0, sigma=2).rvs(size=37) + >>> # pretend `data` is unknown and we want to fit a dist. to it + >>> pn.truncated_normal.fit(data) + params(lower=1.040124, upper=1.082447, mu=-8.097877e-06, sigma=1.033405) + + In scipy, the distribution is defined as + ``stats.truncnorm(a, b, loc, scale)`` where + + .. math:: + + a = \frac{\mathrm{lower bound}} - \mu}{\sigma} + + and + + .. math:: + + b = \frac{x_{\mathrm{upper bound}} - \mu}{\sigma} + + Since ``a`` and ``b`` are directly linked to the location and scale + of the distribution as well as the lower and upper limits, + respectively, it's difficult to use the ``fit`` method of this + distirbution without either knowing a lot about it `a priori` or + assuming just as much. + + References + ---------- + http://scipy.github.io/devdocs/generated/scipy.stats.truncnorm + https://en.wikipedia.org/wiki/Rice_distribution + + See Also + -------- + scipy.stats.rice + numpy.random.exponential + + """ + + dist = stats.truncnorm + param_template = namedtuple('params', ['lower', 'upper', 'mu', 'sigma']) + name = 'truncated normal' + + @staticmethod + @utils.greco_deco + def _process_args(lower=None, upper=None, mu=None, sigma=None, fit=False): + a = None + b = None + if lower is not None and mu is not None and sigma is not None: + a = (lower - mu) / sigma + + if upper is not None and mu is not None and sigma is not None: + b = (upper - mu) / sigma + + loc_key, scale_key = utils._get_loc_scale_keys(fit=fit) + if fit: + akey = 'f0' + bkey = 'f1' + else: + akey = 'a' + bkey = 'b' + return {akey: a, bkey: b, loc_key: mu, scale_key: sigma} + + @classmethod + def fit(cls, data, **guesses): + a, b, mu, sigma = cls._fit(data, **guesses) + lower = a * sigma + mu + upper = b * sigma + mu + return cls.param_template(lower=lower, upper=upper, mu=mu, sigma=sigma) + + __all__ = [ 'normal', 'lognormal', @@ -915,4 +1025,5 @@ def fit(cls, data, **guesses): 'pareto', 'exponential', 'rice', + 'truncated_normal', ] diff --git a/paramnormal/tests/test_dist.py b/paramnormal/tests/test_dist.py index 3dab819..04f986e 100644 --- a/paramnormal/tests/test_dist.py +++ b/paramnormal/tests/test_dist.py @@ -377,3 +377,35 @@ def test_fit(self): (params.sigma, 1.759817171541185), (params.loc, 0), ) + + +class Test_truncated_normal(CheckDist_Mixin): + def setup(self): + self.dist = dist.truncated_normal + self.cargs = [] + self.ckwds = dict(lower=-0.5, upper=2.5, mu=1, sigma=4) + + self.np_rand_fxn = stats.truncnorm.rvs + self.npargs = [-0.375, 0.375] + self.npkwds = dict(loc=1, scale=4) + + def test_processargs(self): + result = self.dist._process_args(lower=-0.5, upper=2.5, mu=1, sigma=4) + expected = dict(a=-0.375, b=0.375, loc=1, scale=4) + assert result == expected + + result = self.dist._process_args(upper=2.5, mu=1, sigma=4, fit=True) + expected = dict(f0=None, f1=0.375, floc=1, fscale=4) + assert result == expected + + @seed + def test_fit(self): + stn = stats.truncnorm(-0.375, 0.375, loc=1, scale=4) + data = stn.rvs(size=37000) + params = self.dist.fit(data, lower=-0.5, mu=1, sigma=4) + check_params( + (params.lower, -0.5), + (params.upper, 2.4999301), + (params.mu, 1), + (params.sigma, 4), + )