Skip to content

Commit

Permalink
Added unit tests for prior classes
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyoneda committed Oct 29, 2024
1 parent c3f4283 commit 3df3468
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/image_deconvolution/test_prior_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
import numpy as np

from cosipy.image_deconvolution.prior_base import PriorBase

def test_PriorBase():
PriorBase.__abstractmethods__ = set()

# no class is allowered
with pytest.raises(TypeError) as e_info:
coefficient = 10
test_model = np.zeros(2)
prior = PriorBase(coefficient, test_model)

# As a test, np.ndarray is added
PriorBase.usable_model_classes.append(np.ndarray)

coefficient = 10
test_model = np.zeros(2)
prior = PriorBase(coefficient, test_model)

# other function tests
with pytest.raises(RuntimeError) as e_info:
prior.log_prior(test_model)
assert e_info.type is NotImplementedError

with pytest.raises(RuntimeError) as e_info:
prior.grad_log_prior(test_model)
assert e_info.type is NotImplementedError
36 changes: 36 additions & 0 deletions tests/image_deconvolution/test_priors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

import astropy.units as u
import numpy as np
import healpy as hp

from cosipy.image_deconvolution.prior_tsv import PriorTSV
from cosipy.image_deconvolution import AllSkyImageModel

def test_PriorTSV():

coefficient = 1.0

nside = 1
allskyimage_model = AllSkyImageModel(nside = nside,
energy_edges = np.array([500.0, 510.0]) * u.keV)
allskyimage_model[:,0] = np.arange(hp.nside2npix(nside)) * allskyimage_model.unit

prior_tsv = PriorTSV(coefficient, allskyimage_model)

assert np.isclose(prior_tsv.log_prior(allskyimage_model), -1176.0)

grad_log_prior_correct = np.array([[ 92.],
[ 76.],
[ 60.],
[ 28.],
[ 40.],
[ -8.],
[ -8.],
[ -24.],
[ -36.],
[ -52.],
[ -68.],
[-100.]]) * u.Unit('cm2 s sr')

assert np.allclose(prior_tsv.grad_log_prior(allskyimage_model), grad_log_prior_correct)

0 comments on commit 3df3468

Please sign in to comment.