Skip to content

Commit

Permalink
Apply c-bata's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Oct 30, 2024
1 parent d37f6c9 commit 13e63a6
Showing 1 changed file with 0 additions and 57 deletions.
57 changes: 0 additions & 57 deletions package/samplers/auto_sampler/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,63 +77,6 @@
)


"""
@pytest.mark.parametrize(
# TODO: 知らん
"sampler_class,expected_has_rng,expected_has_another_sampler", [(AutoSampler, ..., ...)]
)
def test_sampler_reseed_rng(
sampler_class: Callable[[], BaseSampler],
expected_has_rng: bool,
expected_has_another_sampler: bool,
) -> None:
def _extract_attr_name_from_sampler_by_cls(sampler: BaseSampler, cls: Any) -> str | None:
for name, attr in sampler.__dict__.items():
if isinstance(attr, cls):
return name
return None
sampler = sampler_class()
rng_name = _extract_attr_name_from_sampler_by_cls(sampler, LazyRandomState)
has_rng = rng_name is not None
assert expected_has_rng == has_rng
if has_rng:
rng_name = str(rng_name)
original_random_state = sampler.__dict__[rng_name].rng.get_state()
sampler.reseed_rng()
random_state = sampler.__dict__[rng_name].rng.get_state()
if not isinstance(sampler, optuna.samplers.CmaEsSampler):
assert str(original_random_state) != str(random_state)
else:
# CmaEsSampler has a RandomState that is not reseed by its reseed_rng method.
assert str(original_random_state) == str(random_state)
had_sampler_name = _extract_attr_name_from_sampler_by_cls(sampler, BaseSampler)
has_another_sampler = had_sampler_name is not None
assert expected_has_another_sampler == has_another_sampler
if has_another_sampler:
had_sampler_name = str(had_sampler_name)
had_sampler = sampler.__dict__[had_sampler_name]
had_sampler_rng_name = _extract_attr_name_from_sampler_by_cls(had_sampler, LazyRandomState)
original_had_sampler_random_state = had_sampler.__dict__[
had_sampler_rng_name
].rng.get_state()
with patch.object(
had_sampler,
"reseed_rng",
wraps=had_sampler.reseed_rng,
) as mock_object:
sampler.reseed_rng()
assert mock_object.call_count == 1
had_sampler = sampler.__dict__[had_sampler_name]
had_sampler_random_state = had_sampler.__dict__[had_sampler_rng_name].rng.get_state()
assert str(original_had_sampler_random_state) != str(had_sampler_random_state)
"""


def parametrize_suggest_method(name: str) -> MarkDecorator:
return pytest.mark.parametrize(
f"suggest_method_{name}",
Expand Down

0 comments on commit 13e63a6

Please sign in to comment.