From d4ae9dd34da60ee98324c2e6e9d87c9167cde347 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Mon, 17 Aug 2015 13:55:38 -0500 Subject: [PATCH 1/3] Generalizes utils.invcdf to accept n-dim arrays, addressing #53 --- pymc/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc/utils.py b/pymc/utils.py index 2ded27ff..bac9c33d 100644 --- a/pymc/utils.py +++ b/pymc/utils.py @@ -445,8 +445,9 @@ def lognormcdf(x, mu, tau): def invcdf(x): """Inverse of normal cumulative density function.""" - x = np.atleast_1d(x) - return np.array([flib.ppnd16(y, 1) for y in x]) + x_flat = np.ravel(x) + x_trans = np.array([flib.ppnd16(y, 1) for y in x_flat]) + return np.reshape(x_trans, np.shape(x)) def ar1_gen(rho, mu, sigma, size=1): From 2d57b68b94f9d06095e225e1a0322769b6822ba7 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Mon, 17 Aug 2015 14:00:11 -0500 Subject: [PATCH 2/3] Added unit test for invcdf shape --- pymc/tests/test_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pymc/tests/test_utils.py b/pymc/tests/test_utils.py index 43828295..20e05c22 100644 --- a/pymc/tests/test_utils.py +++ b/pymc/tests/test_utils.py @@ -8,7 +8,6 @@ from pymc import six xrange = six.moves.xrange - class test_logp_of_set(TestCase): A = Normal('A', 0, 1) B = Gamma('B', 1, 1) @@ -73,6 +72,19 @@ def test_normcdf_log_3d_input(self): x = arange(8.).reshape(2, 2, 2) utils.normcdf(x, log=True) +class test_invcdf_input_shape(TestCase): + + def test_invcdf_1d_input(self): + x = random.random(8) + utils.invcdf(x) + + def test_normcdf_2d_input(self): + x = random.random((2, 4)) + utils.invcdf(x) + + def test_normcdf_3d_input(self): + x = arange.random((2, 2, 2)) + utils.invcdf(x) if __name__ == '__main__': C = nose.config.Config(verbosity=1) From 8104089c2b8c352cad76670430ac0e53d7b1de8d Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Mon, 17 Aug 2015 14:07:26 -0500 Subject: [PATCH 3/3] Bug fix for invcdf unit test --- pymc/tests/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/tests/test_utils.py b/pymc/tests/test_utils.py index 20e05c22..2389e88e 100644 --- a/pymc/tests/test_utils.py +++ b/pymc/tests/test_utils.py @@ -78,12 +78,12 @@ def test_invcdf_1d_input(self): x = random.random(8) utils.invcdf(x) - def test_normcdf_2d_input(self): + def test_invcdf_2d_input(self): x = random.random((2, 4)) utils.invcdf(x) - def test_normcdf_3d_input(self): - x = arange.random((2, 2, 2)) + def test_invcdf_3d_input(self): + x = random.random((2, 2, 2)) utils.invcdf(x) if __name__ == '__main__':