Skip to content

Commit

Permalink
Validate that an MO acqf is used for MOO in MBM/Acquistion
Browse files Browse the repository at this point in the history
Summary:
This diff adds a validation that botorch_acqf_class is an MO acqf when `TorchOptConfig.is_moo is True`. This should eliminate bugs like facebook#2519, which can happen since the downstream code will otherwise assume SOO.

Note that this only solves MBM side of the bug. Legacy code will still have the buggy behavior.

Differential Revision: D64563992
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 17, 2024
1 parent 8392eec commit 9c194f5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
26 changes: 21 additions & 5 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning, SearchSpaceExhausted
from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UserInputError
from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
Expand All @@ -41,6 +41,10 @@
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.multi_objective.base import (
MultiObjectiveAnalyticAcquisitionFunction,
MultiObjectiveMCAcquisitionFunction,
)
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
from botorch.models.model import Model, ModelDict
Expand Down Expand Up @@ -101,6 +105,19 @@ def __init__(
self.surrogates = surrogates
self.options = options or {}

if torch_opt_config.is_moo and not issubclass(
botorch_acqf_class,
(
MultiObjectiveAnalyticAcquisitionFunction,
MultiObjectiveMCAcquisitionFunction,
),
):
raise UserInputError(
"Acquisition requires a `MultiObjectiveAnalyticAcquisitionFunction` "
"or a `MultiObjectiveMCAcquisitionFunction` class when there are "
f"multiple objectives. Received {botorch_acqf_class=}."
)

# Compute pending and observed points for each surrogate
Xs_pending_and_observed = {
name: _get_X_pending_and_observed(
Expand Down Expand Up @@ -215,12 +232,11 @@ def __init__(
outcome_constraints = torch_opt_config.outcome_constraints
objective_thresholds = torch_opt_config.objective_thresholds
subset_idcs = None
# If objective weights suggest multiple objectives but objective
# thresholds are not specified, infer them using the model that
# has already been subset to avoid re-subsetting it within
# If MOO and some objective thresholds are not specified, infer them using
# the model that has already been subset to avoid re-subsetting it within
# `inter_objective_thresholds`.
if (
objective_weights.nonzero().numel() > 1
torch_opt_config.is_moo
and (
self._objective_thresholds is None
or self._objective_thresholds[torch_opt_config.objective_weights != 0]
Expand Down
45 changes: 34 additions & 11 deletions ax/models/torch/tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning, SearchSpaceExhausted
from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UserInputError
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
Expand All @@ -43,6 +43,9 @@
)
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.acquisition.multi_objective.base import (
MultiObjectiveAnalyticAcquisitionFunction,
)
from botorch.acquisition.multi_objective.monte_carlo import (
qNoisyExpectedHypervolumeImprovement,
)
Expand Down Expand Up @@ -89,22 +92,30 @@ def evaluate(self, X: Tensor, **kwargs: Any) -> Tensor:
return X.sum(dim=-1)


class DummyMultiObjectiveAcquisitionFunction(
DummyAcquisitionFunction, MultiObjectiveAnalyticAcquisitionFunction
):
# Dummy acquisition function for testing multi-objective setup.
...


class AcquisitionTest(TestCase):
def setUp(self) -> None:
super().setUp()
qNEI_input_constructor = get_acqf_input_constructor(qNoisyExpectedImprovement)
# Adding wrapping here to be able to count calls and inspect arguments.
self.mock_input_constructor = mock.MagicMock(
qNEI_input_constructor, side_effect=qNEI_input_constructor
)
# Adding wrapping here to be able to count calls and inspect arguments.
_register_acqf_input_constructor(
acqf_cls=DummyAcquisitionFunction,
input_constructor=self.mock_input_constructor,
)
_register_acqf_input_constructor(
acqf_cls=DummyOneShotAcquisitionFunction,
input_constructor=self.mock_input_constructor,
)
for acqf_class in (
DummyAcquisitionFunction,
DummyOneShotAcquisitionFunction,
DummyMultiObjectiveAcquisitionFunction,
):
_register_acqf_input_constructor(
acqf_cls=acqf_class,
input_constructor=self.mock_input_constructor,
)
tkwargs: dict[str, Any] = {"dtype": torch.double}
self.botorch_model_class = SingleTaskGP
self.surrogate = Surrogate(botorch_model_class=self.botorch_model_class)
Expand Down Expand Up @@ -738,7 +749,7 @@ def test_init_moo(
with_outcome_constraints: bool = True,
) -> None:
acqf_class = (
DummyAcquisitionFunction
DummyMultiObjectiveAcquisitionFunction
if with_no_X_observed
else qNoisyExpectedHypervolumeImprovement
)
Expand Down Expand Up @@ -776,7 +787,19 @@ def test_init_moo(
objective_weights=moo_objective_weights,
outcome_constraints=outcome_constraints,
objective_thresholds=moo_objective_thresholds,
is_moo=True,
)
with self.assertRaisesRegex(
UserInputError, "when there are multiple objectives"
):
Acquisition(
surrogates={"surrogate": self.surrogate},
botorch_acqf_class=DummyAcquisitionFunction,
search_space_digest=self.search_space_digest,
torch_opt_config=torch_opt_config,
options=self.options,
)

acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
botorch_acqf_class=acqf_class,
Expand Down

0 comments on commit 9c194f5

Please sign in to comment.