Skip to content

Commit

Permalink
Merge pull request #208 from nabenabe0928/test/add-tests-for-ctpe
Browse files Browse the repository at this point in the history
Add tests for c-TPE
  • Loading branch information
c-bata authored Dec 17, 2024
2 parents 2512d76 + d59b7ad commit a267f58
Show file tree
Hide file tree
Showing 3 changed files with 929 additions and 3 deletions.
12 changes: 12 additions & 0 deletions package/samplers/ctpe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,15 @@ When you use this sampler, please cite the following:
year={2022}
}
```

### Test

To execute the tests for `cTPESampler`, please run the following commands. The test file is provided in the package.

```sh
pip install pytest
```

```python
pytest package/samplers/ctpe/tests/
```
8 changes: 5 additions & 3 deletions package/samplers/ctpe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __init__(
)

def _warning_multi_objective_for_ctpe(self, study: Study) -> None:
"""TODO: Use this routine once c-TPE supports multi-objective optimization.
if study._is_multi_objective():

def _get_additional_msg() -> str:
beta = getattr(self._gamma, "_beta", None)
strategy = getattr(self._gamma, "_strategy", None)
Expand All @@ -98,6 +98,8 @@ def _get_additional_msg() -> str:
"but sampling will be performed by c-TPE based on Optuna MOTPE. "
f"{_get_additional_msg()}"
)
"""
self._raise_error_if_multi_objective(study)

def _build_parzen_estimators_for_constraints_and_get_quantiles(
self,
Expand All @@ -123,7 +125,7 @@ def _build_parzen_estimators_for_constraints_and_get_quantiles(
study, search_space, unsatisfied_trials, handle_below=False
)
)
quantiles.append(len(satisfied_trials) / len(trials))
quantiles.append(len(satisfied_trials) / max(1, len(trials)))

return mpes_below, mpes_above, quantiles

Expand All @@ -149,7 +151,7 @@ def _sample(
mpes_above.append(
self._build_parzen_estimator(study, search_space, above_trials, handle_below=False)
)
quantiles.append(len(below_trials) / len(trials))
quantiles.append(len(below_trials) / max(1, len(trials)))

_samples_below: dict[str, list[np.ndarray]] = {
param_name: [] for param_name in search_space
Expand Down
Loading

0 comments on commit a267f58

Please sign in to comment.