From 6f9100b58d352904f4534df45f9ff33c05f61a4d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 09:51:06 +0200 Subject: [PATCH 01/20] Expose surrogate and posterior via campaign --- baybe/campaign.py | 50 ++++++++++++++++++++ baybe/recommenders/meta/base.py | 83 ++++++++++++++++++++++++++------- baybe/surrogates/base.py | 19 +++++++- 3 files changed, 133 insertions(+), 19 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index 00d08a1d0..e897d43ef 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING import cattrs import numpy as np @@ -14,7 +15,9 @@ from baybe.objectives.base import Objective, to_objective from baybe.parameters.base import Parameter from baybe.recommenders.base import RecommenderProtocol +from baybe.recommenders.meta.base import MetaRecommender from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender +from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace.core import ( SearchSpace, SearchSpaceType, @@ -22,6 +25,7 @@ validate_searchspace_from_config, ) from baybe.serialization import SerialMixin, converter +from baybe.surrogates.base import Surrogate from baybe.targets.base import Target from baybe.telemetry import ( TELEM_LABELS, @@ -31,6 +35,9 @@ from baybe.utils.boolean import eq_dataframe from baybe.utils.plotting import to_string +if TYPE_CHECKING: + from botorch.posteriors import Posterior + @define class Campaign(SerialMixin): @@ -269,6 +276,49 @@ def recommend( return rec + def posterior(self, candidates: pd.DataFrame) -> Posterior: + """Get the posterior predictive distribution for the given candidates. + + The predictive distribution is based on the surrogate model of the last used + recommender. + + Args: + candidates: The candidate points in experimental recommendations. + For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + + Returns: + Posterior: The corresponding posterior object. + For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + """ + surrogate = self.get_surrogate() + return surrogate.posterior(candidates) + + def get_surrogate(self) -> Surrogate: + """Get the current surrogate model. + + Raises: + RuntimeError: If the current recommender does not provide a surrogate model. + + Returns: + Surrogate: The surrogate of the current recommender. + """ + if isinstance(self.recommender, MetaRecommender): + pure_recommender = self.recommender.get_current_recommender() + else: + pure_recommender = self.recommender + + if isinstance(pure_recommender, BayesianRecommender): + surrogate = pure_recommender.surrogate_model + surrogate.fit(self.searchspace, self.objective, self.measurements) + return surrogate + else: + raise RuntimeError( + f"The current recommender is of type " + f"'{pure_recommender.__class__.__name__}', which does not provide " + f"a surrogate model. Surrogate models are only available for " + f"recommender subclasses of '{BayesianRecommender.__name__}'." + ) + def _add_version(dict_: dict) -> dict: """Add the package version to the given dictionary.""" diff --git a/baybe/recommenders/meta/base.py b/baybe/recommenders/meta/base.py index aff74bbc1..02acb5723 100644 --- a/baybe/recommenders/meta/base.py +++ b/baybe/recommenders/meta/base.py @@ -5,7 +5,7 @@ import cattrs import pandas as pd -from attrs import define +from attrs import define, field from baybe.objectives.base import Objective from baybe.recommenders.base import RecommenderProtocol @@ -20,6 +20,12 @@ class MetaRecommender(SerialMixin, RecommenderProtocol, ABC): """Abstract base class for all meta recommenders.""" + _current_recommender: PureRecommender | None = field(default=None, init=False) + """The current recommender.""" + + _has_been_used_for_recommendation: bool = field(default=False, init=False) + """Flag indicating if the current recommender has already been used.""" + @abstractmethod def select_recommender( self, @@ -31,21 +37,60 @@ def select_recommender( ) -> PureRecommender: """Select a pure recommender for the given experimentation context. - Args: - batch_size: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - searchspace: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - objective: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - measurements: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - pending_experiments: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - - Returns: - The selected recommender. + See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend` for details + on the method arguments. + """ + + def get_current_recommender(self) -> PureRecommender: + """Get the current recommender, if available.""" + if self._current_recommender is None: + raise RuntimeError( + f"No recommendation has been requested from the " + f"'{self.__class__.__name__}' yet. Because the recommender is a " + f"'{MetaRecommender.__name__}', this means no actual recommender has " + f"been selected so far. The recommender will be available after the " + f"next '{self.recommend.__name__}' call." + ) + return self._current_recommender + + def get_next_recommender( + self, + batch_size: int, + searchspace: SearchSpace, + objective: Objective | None = None, + measurements: pd.DataFrame | None = None, + pending_experiments: pd.DataFrame | None = None, + ) -> PureRecommender: + """Get the recommender for the next recommendation. + + Returns the next recommender in row that has not yet been used for generating + recommendations. In case of multiple consecutive calls, this means that + the same recommender instance is returned until its :meth:`recommend` method + is called. + + See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend` for details + on the method arguments. """ + # Check if the stored recommender instance can be returned + if ( + self._current_recommender is not None + and not self._has_been_used_for_recommendation + ): + recommender = self._current_recommender + + # Otherwise, fetch the next recommender waiting in row + else: + recommender = self.select_recommender( + batch_size=batch_size, + searchspace=searchspace, + objective=objective, + measurements=measurements, + pending_experiments=pending_experiments, + ) + self._current_recommender = recommender + self._has_been_used_for_recommendation = False + + return recommender def recommend( self, @@ -55,8 +100,8 @@ def recommend( measurements: pd.DataFrame | None = None, pending_experiments: pd.DataFrame | None = None, ) -> pd.DataFrame: - """See :func:`baybe.recommenders.base.RecommenderProtocol.recommend`.""" - recommender = self.select_recommender( + """See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`.""" + recommender = self.get_next_recommender( batch_size=batch_size, searchspace=searchspace, objective=objective, @@ -76,12 +121,14 @@ def recommend( } ) - return recommender.recommend( + recommendations = recommender.recommend( batch_size=batch_size, searchspace=searchspace, pending_experiments=pending_experiments, **optional_args, ) + self._has_been_used_for_recommendation = True + return recommendations # Register (un-)structure hooks diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 6f891faba..1aae5ca60 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -16,6 +16,7 @@ UnstructuredValue, UnstructureHook, ) +from joblib.hashing import hash from baybe.exceptions import ModelNotTrainedError from baybe.objectives.base import Objective @@ -97,6 +98,12 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): _searchspace: SearchSpace | None = field(init=False, default=None, eq=False) """The search space on which the surrogate operates. Available after fitting.""" + _objective: Objective | None = field(init=False, default=None, eq=False) + """The objective for which the surrogate was trained. Available after fitting.""" + + _measurements_hash = field(init=False, default=None, eq=False) + """The hash of the data the surrogate was trained on.""" + # TODO: type should be # `botorch.models.transforms.outcome.Standardize | _NoTransform` # but is currently omitted due to: @@ -275,6 +282,14 @@ def fit( """ # TODO: consider adding a validation step for `measurements` + # When the context is unchanged, no retraining is necessary + if ( + searchspace == self._searchspace + and objective == self._objective + and hash(measurements) == self._measurements_hash + ): + return + # Check if transfer learning capabilities are needed if (searchspace.n_tasks > 1) and (not self.supports_transfer_learning): raise ValueError( @@ -289,8 +304,10 @@ def fit( "Continuous search spaces are currently only supported by GPs." ) - # Remember on which search space the model is trained + # Remember the training context self._searchspace = searchspace + self._objective = objective + self._measurements_hash = hash(measurements) # Create context-specific transformations self._input_scaler = self._make_input_scaler(searchspace) From b1cc6898e0af557135f0a223e112d1dd832053a8 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 10:36:38 +0200 Subject: [PATCH 02/20] Add caching test --- tests/test_surrogate.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_surrogate.py diff --git a/tests/test_surrogate.py b/tests/test_surrogate.py new file mode 100644 index 000000000..5f6869bc7 --- /dev/null +++ b/tests/test_surrogate.py @@ -0,0 +1,26 @@ +"""Surrogate tests.""" + +from unittest.mock import patch + +from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender +from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate +from baybe.utils.dataframe import add_fake_results + + +@patch.object(GaussianProcessSurrogate, "_fit") +def test_caching(patched, searchspace, objective): + """A second fit call with the same context does not trigger retraining.""" + # Prepare the setting + measurements = RandomRecommender().recommend(3, searchspace, objective) + add_fake_results(measurements, objective.targets) + surrogate = GaussianProcessSurrogate() + + # First call + surrogate.fit(searchspace, objective, measurements) + patched.assert_called() + + patched.reset_mock() + + # Second call + surrogate.fit(searchspace, objective, measurements) + patched.assert_not_called() From 5c746ea582e5cdcfb8704bf23556543fe32d6095 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 10:37:38 +0200 Subject: [PATCH 03/20] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d12b85c8..c93e6f631 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Pure recommenders now have the `allow_recommending_pending_experiments` flag, controlling whether pending experiments are excluded from candidates in purely discrete search spaces +- `get_surrogate` and `posterior` methods to `Campaign` ### Changed - The transition from experimental to computational representation no longer happens From 6e944df2e8bdb46a598165d039441e1f29cf033e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 11:39:19 +0200 Subject: [PATCH 04/20] Add joblib dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e681ebe8e..5e98261d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "exceptiongroup", "funcy>=1.17", "gpytorch>=1.9.1,<2", + "joblib>1.4.0,<2", "ngboost>=0.3.12,<1", "numpy>=1.24.1,<2", "pandas>=1.4.2", From 05394365362e8913a7d47feee49a8959913bc80e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 11:39:43 +0200 Subject: [PATCH 05/20] Add upper version limits to remaining core dependencies --- pyproject.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5e98261d4..fe4f39793 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,12 +32,12 @@ dependencies = [ "botorch>=0.9.3,<1", "cattrs>=23.2.0", "exceptiongroup", - "funcy>=1.17", + "funcy>=1.17,<2", "gpytorch>=1.9.1,<2", "joblib>1.4.0,<2", "ngboost>=0.3.12,<1", "numpy>=1.24.1,<2", - "pandas>=1.4.2", + "pandas>=1.4.2,<3", "protobuf<=3.20.3,<4", "scikit-learn>=1.1.1,<2", "scikit-learn-extra>=0.3.0,<1", @@ -47,10 +47,10 @@ dependencies = [ "typing_extensions>=4.7.0", # Telemetry: - "opentelemetry-sdk>=1.16.0", - "opentelemetry-propagator-aws-xray>=1.0.0", - "opentelemetry-exporter-otlp>=1.16.0", - "opentelemetry-sdk-extension-aws>=2.0.0", + "opentelemetry-sdk>=1.16.0,<2", + "opentelemetry-propagator-aws-xray>=1.0.0,<2", + "opentelemetry-exporter-otlp>=1.16.0,<2", + "opentelemetry-sdk-extension-aws>=2.0.0,<3", ] [project.urls] From fdf25092b2fcfcef46c272d8494499c47968b287 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 13:33:05 +0200 Subject: [PATCH 06/20] Fix mypy issues --- baybe/campaign.py | 20 ++++++++++++++++++-- baybe/surrogates/base.py | 3 +++ mypy.ini | 2 +- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index e897d43ef..91ac4d256 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -12,6 +12,7 @@ from attrs.converters import optional from attrs.validators import instance_of +from baybe.exceptions import IncompatibilityError from baybe.objectives.base import Objective, to_objective from baybe.parameters.base import Parameter from baybe.recommenders.base import RecommenderProtocol @@ -25,7 +26,7 @@ validate_searchspace_from_config, ) from baybe.serialization import SerialMixin, converter -from baybe.surrogates.base import Surrogate +from baybe.surrogates.base import SurrogateProtocol from baybe.targets.base import Target from baybe.telemetry import ( TELEM_LABELS, @@ -286,14 +287,23 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: candidates: The candidate points in experimental recommendations. For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + Raises: + IncompatibilityError: If the underlying surrogate model exposes no + method for computing the posterior distribution. + Returns: Posterior: The corresponding posterior object. For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. """ surrogate = self.get_surrogate() + if not hasattr(surrogate, method_name := "posterior"): + raise IncompatibilityError( + f"The used surrogate type '{surrogate.__class__.__name__}' does not " + f"provide a '{method_name}' method." + ) return surrogate.posterior(candidates) - def get_surrogate(self) -> Surrogate: + def get_surrogate(self) -> SurrogateProtocol: """Get the current surrogate model. Raises: @@ -302,6 +312,12 @@ def get_surrogate(self) -> Surrogate: Returns: Surrogate: The surrogate of the current recommender. """ + if self.objective is None: + raise IncompatibilityError( + f"No surrogate is available since no '{Objective.__name__}' is defined." + ) + + pure_recommender: RecommenderProtocol if isinstance(self.recommender, MetaRecommender): pure_recommender = self.recommender.get_current_recommender() else: diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 1aae5ca60..dd7e3b71e 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -65,6 +65,9 @@ class _NoTransform(Enum): class SurrogateProtocol(Protocol): """Type protocol specifying the interface surrogate models need to implement.""" + # TODO: Final layout still to be optimized. For example, shall we require a + # `posterior` method? + def fit( self, searchspace: SearchSpace, diff --git a/mypy.ini b/mypy.ini index dfe237e75..7a9737aea 100644 --- a/mypy.ini +++ b/mypy.ini @@ -21,7 +21,7 @@ ignore_missing_imports = True [mypy-gpytorch.*] ignore_missing_imports = True -[mypy-joblib] +[mypy-joblib.*] ignore_missing_imports = True [mypy-mordred] From cc73c73b5d2f2a493832fe1855149f497cc7930b Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 13:43:44 +0200 Subject: [PATCH 07/20] Update lockfile --- .lockfiles/py310-dev.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/.lockfiles/py310-dev.lock b/.lockfiles/py310-dev.lock index 0708c013c..d2ea0c649 100644 --- a/.lockfiles/py310-dev.lock +++ b/.lockfiles/py310-dev.lock @@ -255,6 +255,7 @@ jinja2==3.1.4 # torch joblib==1.4.2 # via + # baybe (pyproject.toml) # scikit-learn # xyzpy json5==0.9.25 From ae985929b9884eac88ca788db7de7a90bd49502e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 20:57:30 +0200 Subject: [PATCH 08/20] Add get_surrogate method to BayesianRecommender --- CHANGELOG.md | 1 + baybe/campaign.py | 6 ++-- baybe/recommenders/pure/bayesian/base.py | 37 ++++++++++++++++++++---- tests/test_deprecations.py | 7 +++++ 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c93e6f631..763ca24f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecations - The role of `register_custom_architecture` has been taken over by `baybe.surrogates.base.SurrogateProtocol` +- `BayesianRecommender.surrogate_model` has been replaced with `get_surrogate` ## [0.10.0] - 2024-08-02 ### Breaking Changes diff --git a/baybe/campaign.py b/baybe/campaign.py index 91ac4d256..d4ab7d4f8 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -324,9 +324,9 @@ def get_surrogate(self) -> SurrogateProtocol: pure_recommender = self.recommender if isinstance(pure_recommender, BayesianRecommender): - surrogate = pure_recommender.surrogate_model - surrogate.fit(self.searchspace, self.objective, self.measurements) - return surrogate + return pure_recommender.get_surrogate( + self.searchspace, self.objective, self.measurements + ) else: raise RuntimeError( f"The current recommender is of type " diff --git a/baybe/recommenders/pure/bayesian/base.py b/baybe/recommenders/pure/bayesian/base.py index 48d4c2d63..7f85d5dfd 100644 --- a/baybe/recommenders/pure/bayesian/base.py +++ b/baybe/recommenders/pure/bayesian/base.py @@ -1,9 +1,10 @@ """Base class for all Bayesian recommenders.""" +import warnings from abc import ABC import pandas as pd -from attrs import define, field +from attrs import define, field, fields from baybe.acquisition.acqfs import qLogExpectedImprovement from baybe.acquisition.base import AcquisitionFunction @@ -20,7 +21,9 @@ class BayesianRecommender(PureRecommender, ABC): """An abstract class for Bayesian Recommenders.""" - surrogate_model: SurrogateProtocol = field(factory=GaussianProcessSurrogate) + _surrogate_model: SurrogateProtocol = field( + alias="surrogate_model", factory=GaussianProcessSurrogate + ) """The used surrogate model.""" acquisition_function: AcquisitionFunction = field( @@ -43,6 +46,28 @@ def _validate_deprecated_argument(self, _, value) -> None: "The parameter has been renamed to 'acquisition_function'." ) + @property + def surrogate_model(self) -> SurrogateProtocol: + """Deprecated!""" + warnings.warn( + f"Accessing the surrogate model via 'surrogate_model' has been " + f"deprecated. Use '{self.get_surrogate.__name__}' instead to get the " + f"trained model instance (or " + f"'{fields(type(self))._surrogate_model.name}' to access the raw object).", + DeprecationWarning, + ) + return self._surrogate_model + + def get_surrogate( + self, + searchspace: SearchSpace, + objective: Objective, + measurements: pd.DataFrame, + ) -> SurrogateProtocol: + """Get the trained surrogate model.""" + self._surrogate_model.fit(searchspace, objective, measurements) + return self._surrogate_model + def _setup_botorch_acqf( self, searchspace: SearchSpace, @@ -51,9 +76,9 @@ def _setup_botorch_acqf( pending_experiments: pd.DataFrame | None = None, ) -> None: """Create the acquisition function for the current training data.""" # noqa: E501 - self.surrogate_model.fit(searchspace, objective, measurements) + surrogate = self.get_surrogate(searchspace, objective, measurements) self._botorch_acqf = self.acquisition_function.to_botorch( - self.surrogate_model, + surrogate, searchspace, objective, measurements, @@ -83,12 +108,12 @@ def recommend( # noqa: D102 ) if ( - isinstance(self.surrogate_model, IndependentGaussianSurrogate) + isinstance(self._surrogate_model, IndependentGaussianSurrogate) and batch_size > 1 ): raise InvalidSurrogateModelError( f"The specified surrogate model of type " - f"'{self.surrogate_model.__class__.__name__}' " + f"'{self._surrogate_model.__class__.__name__}' " f"cannot be used for batch recommendation." ) diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index fc2e83a6c..cf53c7625 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -117,3 +117,10 @@ def test_deprecated_surrogate_registration(): with pytest.raises(DeprecationError): register_custom_architecture() + + +def test_deprecated_surrogate_access(): + """Public attribute access to the surrogate model raises a warning.""" + recommender = BotorchRecommender() + with pytest.warns(DeprecationWarning): + recommender.surrogate_model From 8052859204e254dac7a9331c49e15818b22cb37a Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 30 Aug 2024 21:02:38 +0200 Subject: [PATCH 09/20] Disable gradient calculation for posteriors in campaigns --- baybe/campaign.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/baybe/campaign.py b/baybe/campaign.py index d4ab7d4f8..fab361d81 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -301,7 +301,11 @@ def posterior(self, candidates: pd.DataFrame) -> Posterior: f"The used surrogate type '{surrogate.__class__.__name__}' does not " f"provide a '{method_name}' method." ) - return surrogate.posterior(candidates) + + import torch + + with torch.no_grad(): + return surrogate.posterior(candidates) def get_surrogate(self) -> SurrogateProtocol: """Get the current surrogate model. From 14a57b0594d99e2b4374eabecbaa9ab136dd1e46 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 4 Sep 2024 22:20:34 +0200 Subject: [PATCH 10/20] Limit surrogate access to single untransformed targets --- baybe/campaign.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/baybe/campaign.py b/baybe/campaign.py index fab361d81..5ee66b638 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -14,6 +14,7 @@ from baybe.exceptions import IncompatibilityError from baybe.objectives.base import Objective, to_objective +from baybe.objectives.single import SingleTargetObjective from baybe.parameters.base import Parameter from baybe.recommenders.base import RecommenderProtocol from baybe.recommenders.meta.base import MetaRecommender @@ -28,6 +29,7 @@ from baybe.serialization import SerialMixin, converter from baybe.surrogates.base import SurrogateProtocol from baybe.targets.base import Target +from baybe.targets.numerical import NumericalTarget from baybe.telemetry import ( TELEM_LABELS, telemetry_record_recommended_measurement_percentage, @@ -316,6 +318,18 @@ def get_surrogate(self) -> SurrogateProtocol: Returns: Surrogate: The surrogate of the current recommender. """ + # TODO: remove temporary restriction when target transformations can be handled + match self.objective: + case SingleTargetObjective( + _target=NumericalTarget(bounds=b) + ) if not b.is_bounded: + pass + case _: + raise NotImplementedError( + "Surrogate model access is currently only supported for a single " + "untransformed target." + ) + if self.objective is None: raise IncompatibilityError( f"No surrogate is available since no '{Objective.__name__}' is defined." From 01092a58d774b45fcb4b792978472e45abce7df6 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 5 Sep 2024 12:49:03 +0200 Subject: [PATCH 11/20] Rename meta recommender flag --- baybe/recommenders/meta/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/baybe/recommenders/meta/base.py b/baybe/recommenders/meta/base.py index 02acb5723..3d2355675 100644 --- a/baybe/recommenders/meta/base.py +++ b/baybe/recommenders/meta/base.py @@ -23,7 +23,7 @@ class MetaRecommender(SerialMixin, RecommenderProtocol, ABC): _current_recommender: PureRecommender | None = field(default=None, init=False) """The current recommender.""" - _has_been_used_for_recommendation: bool = field(default=False, init=False) + _current_recommender_was_used: bool = field(default=False, init=False) """Flag indicating if the current recommender has already been used.""" @abstractmethod @@ -74,7 +74,7 @@ def get_next_recommender( # Check if the stored recommender instance can be returned if ( self._current_recommender is not None - and not self._has_been_used_for_recommendation + and not self._current_recommender_was_used ): recommender = self._current_recommender @@ -88,7 +88,7 @@ def get_next_recommender( pending_experiments=pending_experiments, ) self._current_recommender = recommender - self._has_been_used_for_recommendation = False + self._current_recommender_was_used = False return recommender @@ -127,7 +127,7 @@ def recommend( pending_experiments=pending_experiments, **optional_args, ) - self._has_been_used_for_recommendation = True + self._current_recommender_was_used = True return recommendations From 3ef526ce3d5e69d4865bf47389068c4d14b54dfb Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 5 Sep 2024 13:12:42 +0200 Subject: [PATCH 12/20] Add temporary workaround to enable meta data attribute serialization --- baybe/recommenders/meta/base.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/baybe/recommenders/meta/base.py b/baybe/recommenders/meta/base.py index 3d2355675..ba4106361 100644 --- a/baybe/recommenders/meta/base.py +++ b/baybe/recommenders/meta/base.py @@ -20,11 +20,24 @@ class MetaRecommender(SerialMixin, RecommenderProtocol, ABC): """Abstract base class for all meta recommenders.""" - _current_recommender: PureRecommender | None = field(default=None, init=False) - """The current recommender.""" - - _current_recommender_was_used: bool = field(default=False, init=False) - """Flag indicating if the current recommender has already been used.""" + # TODO: The attributes should be `init=False` but this currently prevents them from + # being serialized. The reason is that setting `_cattrs_include_init_false=True` + # for this class has currently no effect when serializing it as + # a `RecommenderProtocol`, since the hook of the latter does not reuse the + # hook of the actual class. Fix is already planned and also needed for other + # reasons. Until that, as a workaround, we expose the attributes as "private" + # attributes. + + _current_recommender: PureRecommender | None = field( + alias="_current_recommender", default=None, kw_only=True + ) + """The current recommender. (For internal use only!)""" + + _current_recommender_was_used: bool = field( + alias="_current_recommender_was_used", default=False, kw_only=True + ) + """Flag indicating if the current recommender has already been used. + (For internal use only!)""" @abstractmethod def select_recommender( From c9b2a5ca6e91a97b7503320e75f726bf8c37107c Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 6 Sep 2024 15:20:43 +0200 Subject: [PATCH 13/20] Add missing type annotation for measurement hash --- baybe/surrogates/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index dd7e3b71e..7c86519a9 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -104,7 +104,7 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): _objective: Objective | None = field(init=False, default=None, eq=False) """The objective for which the surrogate was trained. Available after fitting.""" - _measurements_hash = field(init=False, default=None, eq=False) + _measurements_hash: str = field(init=False, default=None, eq=False) """The hash of the data the surrogate was trained on.""" # TODO: type should be From e0a4258e38a4e1d8b895433d70e6d9a8c26cf18c Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 6 Sep 2024 16:06:08 +0200 Subject: [PATCH 14/20] Make protocol classes slotted --- baybe/recommenders/base.py | 4 ++++ baybe/serialization/mixin.py | 2 +- baybe/surrogates/base.py | 4 ++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/baybe/recommenders/base.py b/baybe/recommenders/base.py index d4de83b85..92af59bd9 100644 --- a/baybe/recommenders/base.py +++ b/baybe/recommenders/base.py @@ -15,6 +15,10 @@ class RecommenderProtocol(Protocol): """Type protocol specifying the interface recommenders need to implement.""" + # Use slots so that derived classes also remain slotted + # See also: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes + __slots__ = () + def recommend( self, batch_size: int, diff --git a/baybe/serialization/mixin.py b/baybe/serialization/mixin.py index d423cb572..8f98fc63c 100644 --- a/baybe/serialization/mixin.py +++ b/baybe/serialization/mixin.py @@ -11,7 +11,7 @@ class SerialMixin: """A mixin class providing serialization functionality.""" - # Use slots so that the derived classes also remain slotted + # Use slots so that derived classes also remain slotted # See also: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes __slots__ = () diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 7c86519a9..4871d0461 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -65,6 +65,10 @@ class _NoTransform(Enum): class SurrogateProtocol(Protocol): """Type protocol specifying the interface surrogate models need to implement.""" + # Use slots so that derived classes also remain slotted + # See also: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes + __slots__ = () + # TODO: Final layout still to be optimized. For example, shall we require a # `posterior` method? From 24b4d419476c7efc33aefaa70188f54bb9f3481e Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 6 Sep 2024 16:06:48 +0200 Subject: [PATCH 15/20] Add missing _input_scaler attribute --- baybe/surrogates/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 4871d0461..1b4e80865 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -111,8 +111,14 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): _measurements_hash: str = field(init=False, default=None, eq=False) """The hash of the data the surrogate was trained on.""" + _input_scaler: ColumnTransformer | None = field(init=False, default=None, eq=False) + """Scaler for transforming input values. Available after fitting. + + Scales a tensor containing parameter configurations in computational representation + to make them digestible for the model-specific, scale-agnostic posterior logic.""" + # TODO: type should be - # `botorch.models.transforms.outcome.Standardize | _NoTransform` + # `botorch.models.transforms.outcome.Standardize | _NoTransform` | None # but is currently omitted due to: # https://github.com/python-attrs/cattrs/issues/531 _output_scaler = field(init=False, default=None, eq=False) From ab564e6ee623e6da836f43215c640dcd7c431380 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 6 Sep 2024 16:22:01 +0200 Subject: [PATCH 16/20] Temporarily disable slots for recommenders --- baybe/recommenders/pure/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/baybe/recommenders/pure/base.py b/baybe/recommenders/pure/base.py index 10ec66fd3..1004389b8 100644 --- a/baybe/recommenders/pure/base.py +++ b/baybe/recommenders/pure/base.py @@ -15,7 +15,10 @@ from baybe.searchspace.discrete import SubspaceDiscrete -@define +# TODO: Slots are currently disabled since they also block the monkeypatching necessary +# to use `register_hooks`. Probably, we need to update our documentation and +# explain how to work around that before we re-enable slots. +@define(slots=False) class PureRecommender(ABC, RecommenderProtocol): """Abstract base class for all pure recommenders.""" From 79cb6f66c1aaeae59543b8ccf39d1275147caf9f Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 6 Sep 2024 16:59:39 +0200 Subject: [PATCH 17/20] Improve workaround and also apply it for surrogate model renaming --- baybe/recommenders/base.py | 25 +++++++++++++++++++++++-- baybe/recommenders/meta/base.py | 23 +++++------------------ 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/baybe/recommenders/base.py b/baybe/recommenders/base.py index 92af59bd9..10a8157c7 100644 --- a/baybe/recommenders/base.py +++ b/baybe/recommenders/base.py @@ -4,6 +4,7 @@ import cattrs import pandas as pd +from cattrs import override from baybe.objectives.base import Objective from baybe.searchspace import SearchSpace @@ -51,15 +52,35 @@ def recommend( ... +# TODO: The workarounds below are currently required since the hooks created through +# `unstructure_base` and `get_base_structure_hook` do not reuse the hooks of the +# actual class, hence we cannot control things there. Fix is already planned and also +# needed for other reasons. + # Register (un-)structure hooks converter.register_unstructure_hook( RecommenderProtocol, lambda x: unstructure_base( x, # TODO: Remove once deprecation got expired: - overrides=dict(acquisition_function_cls=cattrs.override(omit=True)), + overrides=dict( + acquisition_function_cls=cattrs.override(omit=True), + # Temporary workaround (see TODO note above) + _surrogate_model=override(rename="surrogate_model"), + _current_recommender=override(omit=False), + _current_recommender_was_used=override(omit=False), + ), ), ) converter.register_structure_hook( - RecommenderProtocol, get_base_structure_hook(RecommenderProtocol) + RecommenderProtocol, + get_base_structure_hook( + RecommenderProtocol, + # Temporary workaround (see TODO note above) + overrides=dict( + _surrogate_model=override(rename="surrogate_model"), + _current_recommender=override(omit=False), + _current_recommender_was_used=override(omit=False), + ), + ), ) diff --git a/baybe/recommenders/meta/base.py b/baybe/recommenders/meta/base.py index ba4106361..3d2355675 100644 --- a/baybe/recommenders/meta/base.py +++ b/baybe/recommenders/meta/base.py @@ -20,24 +20,11 @@ class MetaRecommender(SerialMixin, RecommenderProtocol, ABC): """Abstract base class for all meta recommenders.""" - # TODO: The attributes should be `init=False` but this currently prevents them from - # being serialized. The reason is that setting `_cattrs_include_init_false=True` - # for this class has currently no effect when serializing it as - # a `RecommenderProtocol`, since the hook of the latter does not reuse the - # hook of the actual class. Fix is already planned and also needed for other - # reasons. Until that, as a workaround, we expose the attributes as "private" - # attributes. - - _current_recommender: PureRecommender | None = field( - alias="_current_recommender", default=None, kw_only=True - ) - """The current recommender. (For internal use only!)""" - - _current_recommender_was_used: bool = field( - alias="_current_recommender_was_used", default=False, kw_only=True - ) - """Flag indicating if the current recommender has already been used. - (For internal use only!)""" + _current_recommender: PureRecommender | None = field(default=None, init=False) + """The current recommender.""" + + _current_recommender_was_used: bool = field(default=False, init=False) + """Flag indicating if the current recommender has already been used.""" @abstractmethod def select_recommender( From 0aaf300458be18948ad8590eafe2bc266af6a453 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 6 Sep 2024 18:07:10 +0200 Subject: [PATCH 18/20] Silence mypy error --- baybe/surrogates/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 1b4e80865..c07ef9b13 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -235,6 +235,10 @@ def _posterior_comp(self, candidates_comp: Tensor, /) -> Posterior: The same :class:`botorch.posteriors.Posterior` object as returned via :meth:`baybe.surrogates.base.Surrogate.posterior`. """ + # FIXME[typing]: It seems there is currently no better way to inform the type + # checker that the attribute is available at the time of the function call + assert self._input_scaler is not None + p = self._posterior(self._input_scaler.transform(candidates_comp)) if self._output_scaler is not _IDENTITY_TRANSFORM: p = self._output_scaler.untransform_posterior(p) From 737c93723bfae5d188aa5b63e7c4524bd74110d7 Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Sat, 7 Sep 2024 20:15:13 +0200 Subject: [PATCH 19/20] Add comment --- baybe/recommenders/pure/bayesian/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/baybe/recommenders/pure/bayesian/base.py b/baybe/recommenders/pure/bayesian/base.py index 7f85d5dfd..aeb39bfab 100644 --- a/baybe/recommenders/pure/bayesian/base.py +++ b/baybe/recommenders/pure/bayesian/base.py @@ -65,7 +65,9 @@ def get_surrogate( measurements: pd.DataFrame, ) -> SurrogateProtocol: """Get the trained surrogate model.""" + # This fit applies internal caching and does not necessarily involve computation self._surrogate_model.fit(searchspace, objective, measurements) + return self._surrogate_model def _setup_botorch_acqf( From 46441b9afda30a9b4990a48a554351aceb91e26b Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Sun, 8 Sep 2024 23:04:37 +0200 Subject: [PATCH 20/20] Use set of recommender ids --- baybe/recommenders/base.py | 4 ++-- baybe/recommenders/meta/base.py | 10 +++++----- baybe/serialization/core.py | 6 ++++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/baybe/recommenders/base.py b/baybe/recommenders/base.py index 10a8157c7..398d23886 100644 --- a/baybe/recommenders/base.py +++ b/baybe/recommenders/base.py @@ -68,7 +68,7 @@ def recommend( # Temporary workaround (see TODO note above) _surrogate_model=override(rename="surrogate_model"), _current_recommender=override(omit=False), - _current_recommender_was_used=override(omit=False), + _used_recommender_ids=override(omit=False), ), ), ) @@ -80,7 +80,7 @@ def recommend( overrides=dict( _surrogate_model=override(rename="surrogate_model"), _current_recommender=override(omit=False), - _current_recommender_was_used=override(omit=False), + _used_recommender_ids=override(omit=False), ), ), ) diff --git a/baybe/recommenders/meta/base.py b/baybe/recommenders/meta/base.py index 3d2355675..60dcfdead 100644 --- a/baybe/recommenders/meta/base.py +++ b/baybe/recommenders/meta/base.py @@ -23,8 +23,8 @@ class MetaRecommender(SerialMixin, RecommenderProtocol, ABC): _current_recommender: PureRecommender | None = field(default=None, init=False) """The current recommender.""" - _current_recommender_was_used: bool = field(default=False, init=False) - """Flag indicating if the current recommender has already been used.""" + _used_recommender_ids: set[int] = field(factory=set, init=False) + """Set of ids from recommenders that were used by this meta recommender.""" @abstractmethod def select_recommender( @@ -74,7 +74,7 @@ def get_next_recommender( # Check if the stored recommender instance can be returned if ( self._current_recommender is not None - and not self._current_recommender_was_used + and id(self._current_recommender) not in self._used_recommender_ids ): recommender = self._current_recommender @@ -88,7 +88,6 @@ def get_next_recommender( pending_experiments=pending_experiments, ) self._current_recommender = recommender - self._current_recommender_was_used = False return recommender @@ -127,7 +126,8 @@ def recommend( pending_experiments=pending_experiments, **optional_args, ) - self._current_recommender_was_used = True + self._used_recommender_ids.add(id(recommender)) + return recommendations diff --git a/baybe/serialization/core.py b/baybe/serialization/core.py index ea0cbcf43..2947d4cd0 100644 --- a/baybe/serialization/core.py +++ b/baybe/serialization/core.py @@ -18,7 +18,9 @@ # TODO: This urgently needs the `forbid_extra_keys=True` flag, which requires us to # switch to the cattrs built-in subclass recommender. -converter = cattrs.Converter() +# Using GenConverter for built-in overrides for sets, see +# https://catt.rs/en/latest/indepth.html#customizing-collection-unstructuring +converter = cattrs.GenConverter(unstruct_collection_overrides={set: list}) """The default converter for (de-)serializing BayBE-related objects.""" configure_union_passthrough(bool | int | float | str, converter) @@ -158,6 +160,6 @@ def select_constructor_hook(specs: dict, cls: type[_T]) -> _T: return converter.structure_attrs_fromdict(specs, cls) -# Register un-/structure hooks +# Register custom un-/structure hooks converter.register_unstructure_hook(pd.DataFrame, _unstructure_dataframe_hook) converter.register_structure_hook(pd.DataFrame, _structure_dataframe_hook)