Skip to content

Commit

Permalink
Implement test for user-defined mask in halo finder.
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleaoman committed Oct 11, 2023
1 parent 929f3e6 commit ca3cf5b
Showing 1 changed file with 29 additions and 50 deletions.
79 changes: 29 additions & 50 deletions tests/test_halo_finders.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,56 +95,35 @@ def test_get_void_extra_mask(self, hf):
assert getattr(generated_extra_mask, particle_type) is None
remove_toysnap()

# @pytest.mark.parametrize(
# "extra_mask, expected",
# (
# (
# "bound_only",
# dict(gas=n_g, dark_matter=n_dm, stars=n_s, black_holes=n_bh),
# ),
# (None, dict(gas=None, dark_matter=None, stars=None, black_holes=None)),
# (
# MaskCollection(
# gas=np.r_[
# np.ones(100, dtype=bool), np.zeros(n_g_all - 100, dtype=bool)
# ],
# dark_matter=None,
# stars=np.r_[
# np.ones(100, dtype=bool), np.zeros(n_s - 100, dtype=bool)
# ],
# black_holes=np.ones(n_bh, dtype=bool),
# ),
# dict(gas=100, dark_matter=None, stars=100, black_holes=n_bh),
# ),
# ),
# )
# def test_get_user_extra_mask(self, sg, hf, extra_mask, expected):
# """
# Check that extra masks of different kinds have the right shape or type.
# """
# if hasattr(hf, "_caesar"):
# if hf.group_type == "galaxy" and extra_mask == "bound_only":
# expected["dark_matter"] = 0
# snap = h5py.File(sg.filename, "r")
# for k in expected.keys():
# if hasattr(expected[k], "shape"):
# expected[k] = expected[k][
# : snap[
# "Cells/Counts/"
# "PartType{dict(gas=0, dark_matter=1, stars=4, black_holes=5)[k]}"
# ][0]
# ]
# hf.extra_mask = extra_mask
# generated_extra_mask = hf._get_extra_mask(sg)
# print(generated_extra_mask.gas.shape)
# for particle_type in present_particle_types.values():
# if getattr(generated_extra_mask, particle_type) is not None:
# assert (
# getattr(generated_extra_mask, particle_type).sum()
# == expected[particle_type]
# )
# else:
# assert expected[particle_type] is None
def test_get_user_extra_mask(self, hf):
"""
Check that extra masks of different kinds have the right shape or type.
"""
hf.extra_mask = MaskCollection(
gas=np.r_[np.ones(100, dtype=bool), np.zeros(n_g_all - 100, dtype=bool)],
dark_matter=None,
stars=np.r_[np.ones(100, dtype=bool), np.zeros(n_s - 100, dtype=bool)],
black_holes=np.ones(n_bh, dtype=bool),
)
create_toysnap()
sg = SWIFTGalaxy(toysnap_filename, hf)
generated_extra_mask = sg._extra_mask
for particle_type in present_particle_types.values():
if getattr(generated_extra_mask, particle_type) is None:
assert (
dict(gas=100, dark_matter=None, stars=100, black_holes=n_bh)[
particle_type
]
is None
)
else:
assert (
getattr(generated_extra_mask, particle_type).sum()
== dict(gas=100, dark_matter=None, stars=100, black_holes=n_bh)[
particle_type
]
)
remove_toysnap()

def test_centre(self, hf):
"""
Expand Down

0 comments on commit ca3cf5b

Please sign in to comment.