diff --git a/tests/image_deconvolution/test_prior_base.py b/tests/image_deconvolution/test_prior_base.py new file mode 100644 index 00000000..60a08aa9 --- /dev/null +++ b/tests/image_deconvolution/test_prior_base.py @@ -0,0 +1,29 @@ +import pytest +import numpy as np + +from cosipy.image_deconvolution.prior_base import PriorBase + +def test_PriorBase(): + PriorBase.__abstractmethods__ = set() + + # no class is allowered + with pytest.raises(TypeError) as e_info: + coefficient = 10 + test_model = np.zeros(2) + prior = PriorBase(coefficient, test_model) + + # As a test, np.ndarray is added + PriorBase.usable_model_classes.append(np.ndarray) + + coefficient = 10 + test_model = np.zeros(2) + prior = PriorBase(coefficient, test_model) + + # other function tests + with pytest.raises(RuntimeError) as e_info: + prior.log_prior(test_model) + assert e_info.type is NotImplementedError + + with pytest.raises(RuntimeError) as e_info: + prior.grad_log_prior(test_model) + assert e_info.type is NotImplementedError diff --git a/tests/image_deconvolution/test_priors.py b/tests/image_deconvolution/test_priors.py new file mode 100644 index 00000000..2aa10ee7 --- /dev/null +++ b/tests/image_deconvolution/test_priors.py @@ -0,0 +1,36 @@ +import pytest + +import astropy.units as u +import numpy as np +import healpy as hp + +from cosipy.image_deconvolution.prior_tsv import PriorTSV +from cosipy.image_deconvolution import AllSkyImageModel + +def test_PriorTSV(): + + coefficient = 1.0 + + nside = 1 + allskyimage_model = AllSkyImageModel(nside = nside, + energy_edges = np.array([500.0, 510.0]) * u.keV) + allskyimage_model[:,0] = np.arange(hp.nside2npix(nside)) * allskyimage_model.unit + + prior_tsv = PriorTSV(coefficient, allskyimage_model) + + assert np.isclose(prior_tsv.log_prior(allskyimage_model), -1176.0) + + grad_log_prior_correct = np.array([[ 92.], + [ 76.], + [ 60.], + [ 28.], + [ 40.], + [ -8.], + [ -8.], + [ -24.], + [ -36.], + [ -52.], + [ -68.], + [-100.]]) * u.Unit('cm2 s sr') + + assert np.allclose(prior_tsv.grad_log_prior(allskyimage_model), grad_log_prior_correct)