diff --git a/.lockfiles/py310-dev.lock b/.lockfiles/py310-dev.lock index 0708c013c..d2ea0c649 100644 --- a/.lockfiles/py310-dev.lock +++ b/.lockfiles/py310-dev.lock @@ -255,6 +255,7 @@ jinja2==3.1.4 # torch joblib==1.4.2 # via + # baybe (pyproject.toml) # scikit-learn # xyzpy json5==0.9.25 diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d12b85c8..763ca24f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Pure recommenders now have the `allow_recommending_pending_experiments` flag, controlling whether pending experiments are excluded from candidates in purely discrete search spaces +- `get_surrogate` and `posterior` methods to `Campaign` ### Changed - The transition from experimental to computational representation no longer happens @@ -62,6 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecations - The role of `register_custom_architecture` has been taken over by `baybe.surrogates.base.SurrogateProtocol` +- `BayesianRecommender.surrogate_model` has been replaced with `get_surrogate` ## [0.10.0] - 2024-08-02 ### Breaking Changes diff --git a/baybe/campaign.py b/baybe/campaign.py index 00d08a1d0..5ee66b638 100644 --- a/baybe/campaign.py +++ b/baybe/campaign.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING import cattrs import numpy as np @@ -11,10 +12,14 @@ from attrs.converters import optional from attrs.validators import instance_of +from baybe.exceptions import IncompatibilityError from baybe.objectives.base import Objective, to_objective +from baybe.objectives.single import SingleTargetObjective from baybe.parameters.base import Parameter from baybe.recommenders.base import RecommenderProtocol +from baybe.recommenders.meta.base import MetaRecommender from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender +from baybe.recommenders.pure.bayesian.base 
import BayesianRecommender from baybe.searchspace.core ( SearchSpace, SearchSpaceType, @@ -22,7 +27,9 @@ validate_searchspace_from_config, ) from baybe.serialization import SerialMixin, converter +from baybe.surrogates.base import SurrogateProtocol from baybe.targets.base import Target +from baybe.targets.numerical import NumericalTarget from baybe.telemetry import ( TELEM_LABELS, telemetry_record_recommended_measurement_percentage, @@ -31,6 +38,9 @@ from baybe.utils.boolean import eq_dataframe from baybe.utils.plotting import to_string +if TYPE_CHECKING: + from botorch.posteriors import Posterior + @define class Campaign(SerialMixin): @@ -269,6 +279,80 @@ def recommend( return rec + def posterior(self, candidates: pd.DataFrame) -> Posterior: + """Get the posterior predictive distribution for the given candidates. + + The predictive distribution is based on the surrogate model of the last used + recommender. + + Args: + candidates: The candidate points in experimental representation. + For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + + Raises: + IncompatibilityError: If the underlying surrogate model exposes no + method for computing the posterior distribution. + + Returns: + Posterior: The corresponding posterior object. + For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`. + """ + surrogate = self.get_surrogate() + if not hasattr(surrogate, method_name := "posterior"): + raise IncompatibilityError( + f"The used surrogate type '{surrogate.__class__.__name__}' does not " + f"provide a '{method_name}' method." + ) + + import torch + + with torch.no_grad(): + return surrogate.posterior(candidates) + + def get_surrogate(self) -> SurrogateProtocol: + """Get the current surrogate model. + + Raises: + RuntimeError: If the current recommender does not provide a surrogate model. + + Returns: + Surrogate: The surrogate of the current recommender. 
+ """ + if self.objective is None: + raise IncompatibilityError( + f"No surrogate is available since no '{Objective.__name__}' is defined." + ) + + # TODO: remove temporary restriction when target transformations can be handled + match self.objective: + case SingleTargetObjective( + _target=NumericalTarget(bounds=b) + ) if not b.is_bounded: + pass + case _: + raise NotImplementedError( + "Surrogate model access is currently only supported for a single " + "untransformed target." + ) + + pure_recommender: RecommenderProtocol + if isinstance(self.recommender, MetaRecommender): + pure_recommender = self.recommender.get_current_recommender() + else: + pure_recommender = self.recommender + + if isinstance(pure_recommender, BayesianRecommender): + return pure_recommender.get_surrogate( + self.searchspace, self.objective, self.measurements + ) + else: + raise RuntimeError( + f"The current recommender is of type " + f"'{pure_recommender.__class__.__name__}', which does not provide " + f"a surrogate model. Surrogate models are only available for " + f"recommender subclasses of '{BayesianRecommender.__name__}'." + ) + def _add_version(dict_: dict) -> dict: """Add the package version to the given dictionary.""" diff --git a/baybe/recommenders/base.py b/baybe/recommenders/base.py index d4de83b85..398d23886 100644 --- a/baybe/recommenders/base.py +++ b/baybe/recommenders/base.py @@ -4,6 +4,7 @@ import cattrs import pandas as pd +from cattrs import override from baybe.objectives.base import Objective from baybe.searchspace import SearchSpace @@ -15,6 +16,10 @@ class RecommenderProtocol(Protocol): """Type protocol specifying the interface recommenders need to implement.""" + # Use slots so that derived classes also remain slotted + # See also: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes + __slots__ = () + def recommend( self, batch_size: int, @@ -47,15 +52,35 @@ def recommend( ... 
+# TODO: The workarounds below are currently required since the hooks created through +# `unstructure_base` and `get_base_structure_hook` do not reuse the hooks of the +# actual class, hence we cannot control things there. Fix is already planned and also +# needed for other reasons. + # Register (un-)structure hooks converter.register_unstructure_hook( RecommenderProtocol, lambda x: unstructure_base( x, # TODO: Remove once deprecation got expired: - overrides=dict(acquisition_function_cls=cattrs.override(omit=True)), + overrides=dict( + acquisition_function_cls=cattrs.override(omit=True), + # Temporary workaround (see TODO note above) + _surrogate_model=override(rename="surrogate_model"), + _current_recommender=override(omit=False), + _used_recommender_ids=override(omit=False), + ), ), ) converter.register_structure_hook( - RecommenderProtocol, get_base_structure_hook(RecommenderProtocol) + RecommenderProtocol, + get_base_structure_hook( + RecommenderProtocol, + # Temporary workaround (see TODO note above) + overrides=dict( + _surrogate_model=override(rename="surrogate_model"), + _current_recommender=override(omit=False), + _used_recommender_ids=override(omit=False), + ), + ), ) diff --git a/baybe/recommenders/meta/base.py b/baybe/recommenders/meta/base.py index aff74bbc1..60dcfdead 100644 --- a/baybe/recommenders/meta/base.py +++ b/baybe/recommenders/meta/base.py @@ -5,7 +5,7 @@ import cattrs import pandas as pd -from attrs import define +from attrs import define, field from baybe.objectives.base import Objective from baybe.recommenders.base import RecommenderProtocol @@ -20,6 +20,12 @@ class MetaRecommender(SerialMixin, RecommenderProtocol, ABC): """Abstract base class for all meta recommenders.""" + _current_recommender: PureRecommender | None = field(default=None, init=False) + """The current recommender.""" + + _used_recommender_ids: set[int] = field(factory=set, init=False) + """Set of ids from recommenders that were used by this meta recommender.""" + 
@abstractmethod def select_recommender( self, @@ -31,22 +37,60 @@ def select_recommender( ) -> PureRecommender: """Select a pure recommender for the given experimentation context. - Args: - batch_size: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - searchspace: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - objective: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - measurements: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - pending_experiments: - See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`. - - Returns: - The selected recommender. + See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend` for details + on the method arguments. """ + def get_current_recommender(self) -> PureRecommender: + """Get the current recommender, if available.""" + if self._current_recommender is None: + raise RuntimeError( + f"No recommendation has been requested from the " + f"'{self.__class__.__name__}' yet. Because the recommender is a " + f"'{MetaRecommender.__name__}', this means no actual recommender has " + f"been selected so far. The recommender will be available after the " + f"next '{self.recommend.__name__}' call." + ) + return self._current_recommender + + def get_next_recommender( + self, + batch_size: int, + searchspace: SearchSpace, + objective: Objective | None = None, + measurements: pd.DataFrame | None = None, + pending_experiments: pd.DataFrame | None = None, + ) -> PureRecommender: + """Get the recommender for the next recommendation. + + Returns the next recommender in row that has not yet been used for generating + recommendations. In case of multiple consecutive calls, this means that + the same recommender instance is returned until its :meth:`recommend` method + is called. + + See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend` for details + on the method arguments. 
+ """ + # Check if the stored recommender instance can be returned + if ( + self._current_recommender is not None + and id(self._current_recommender) not in self._used_recommender_ids + ): + recommender = self._current_recommender + + # Otherwise, fetch the next recommender waiting in row + else: + recommender = self.select_recommender( + batch_size=batch_size, + searchspace=searchspace, + objective=objective, + measurements=measurements, + pending_experiments=pending_experiments, + ) + self._current_recommender = recommender + + return recommender + def recommend( self, batch_size: int, @@ -55,8 +99,8 @@ def recommend( measurements: pd.DataFrame | None = None, pending_experiments: pd.DataFrame | None = None, ) -> pd.DataFrame: - """See :func:`baybe.recommenders.base.RecommenderProtocol.recommend`.""" - recommender = self.select_recommender( + """See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`.""" + recommender = self.get_next_recommender( batch_size=batch_size, searchspace=searchspace, objective=objective, @@ -76,12 +120,15 @@ def recommend( } ) - return recommender.recommend( + recommendations = recommender.recommend( batch_size=batch_size, searchspace=searchspace, pending_experiments=pending_experiments, **optional_args, ) + self._used_recommender_ids.add(id(recommender)) + + return recommendations # Register (un-)structure hooks diff --git a/baybe/recommenders/pure/base.py b/baybe/recommenders/pure/base.py index 10ec66fd3..1004389b8 100644 --- a/baybe/recommenders/pure/base.py +++ b/baybe/recommenders/pure/base.py @@ -15,7 +15,10 @@ from baybe.searchspace.discrete import SubspaceDiscrete -@define +# TODO: Slots are currently disabled since they also block the monkeypatching necessary +# to use `register_hooks`. Probably, we need to update our documentation and +# explain how to work around that before we re-enable slots. 
+@define(slots=False) class PureRecommender(ABC, RecommenderProtocol): """Abstract base class for all pure recommenders.""" diff --git a/baybe/recommenders/pure/bayesian/base.py b/baybe/recommenders/pure/bayesian/base.py index 48d4c2d63..aeb39bfab 100644 --- a/baybe/recommenders/pure/bayesian/base.py +++ b/baybe/recommenders/pure/bayesian/base.py @@ -1,9 +1,10 @@ """Base class for all Bayesian recommenders.""" +import warnings from abc import ABC import pandas as pd -from attrs import define, field +from attrs import define, field, fields from baybe.acquisition.acqfs import qLogExpectedImprovement from baybe.acquisition.base import AcquisitionFunction @@ -20,7 +21,9 @@ class BayesianRecommender(PureRecommender, ABC): """An abstract class for Bayesian Recommenders.""" - surrogate_model: SurrogateProtocol = field(factory=GaussianProcessSurrogate) + _surrogate_model: SurrogateProtocol = field( + alias="surrogate_model", factory=GaussianProcessSurrogate + ) """The used surrogate model.""" acquisition_function: AcquisitionFunction = field( @@ -43,6 +46,30 @@ def _validate_deprecated_argument(self, _, value) -> None: "The parameter has been renamed to 'acquisition_function'." ) + @property + def surrogate_model(self) -> SurrogateProtocol: + """Deprecated!""" + warnings.warn( + f"Accessing the surrogate model via 'surrogate_model' has been " + f"deprecated. 
Use '{self.get_surrogate.__name__}' instead to get the " + f"trained model instance (or " + f"'{fields(type(self))._surrogate_model.name}' to access the raw object).", + DeprecationWarning, + ) + return self._surrogate_model + + def get_surrogate( + self, + searchspace: SearchSpace, + objective: Objective, + measurements: pd.DataFrame, + ) -> SurrogateProtocol: + """Get the trained surrogate model.""" + # This fit applies internal caching and does not necessarily involve computation + self._surrogate_model.fit(searchspace, objective, measurements) + + return self._surrogate_model + def _setup_botorch_acqf( self, searchspace: SearchSpace, @@ -51,9 +78,9 @@ def _setup_botorch_acqf( pending_experiments: pd.DataFrame | None = None, ) -> None: """Create the acquisition function for the current training data.""" # noqa: E501 - self.surrogate_model.fit(searchspace, objective, measurements) + surrogate = self.get_surrogate(searchspace, objective, measurements) self._botorch_acqf = self.acquisition_function.to_botorch( - self.surrogate_model, + surrogate, searchspace, objective, measurements, @@ -83,12 +110,12 @@ def recommend( # noqa: D102 ) if ( - isinstance(self.surrogate_model, IndependentGaussianSurrogate) + isinstance(self._surrogate_model, IndependentGaussianSurrogate) and batch_size > 1 ): raise InvalidSurrogateModelError( f"The specified surrogate model of type " - f"'{self.surrogate_model.__class__.__name__}' " + f"'{self._surrogate_model.__class__.__name__}' " f"cannot be used for batch recommendation." ) diff --git a/baybe/serialization/core.py b/baybe/serialization/core.py index ea0cbcf43..2947d4cd0 100644 --- a/baybe/serialization/core.py +++ b/baybe/serialization/core.py @@ -18,7 +18,9 @@ # TODO: This urgently needs the `forbid_extra_keys=True` flag, which requires us to # switch to the cattrs built-in subclass recommender. 
-converter = cattrs.Converter() +# Using GenConverter for built-in overrides for sets, see +# https://catt.rs/en/latest/indepth.html#customizing-collection-unstructuring +converter = cattrs.GenConverter(unstruct_collection_overrides={set: list}) """The default converter for (de-)serializing BayBE-related objects.""" configure_union_passthrough(bool | int | float | str, converter) @@ -158,6 +160,6 @@ def select_constructor_hook(specs: dict, cls: type[_T]) -> _T: return converter.structure_attrs_fromdict(specs, cls) -# Register un-/structure hooks +# Register custom un-/structure hooks converter.register_unstructure_hook(pd.DataFrame, _unstructure_dataframe_hook) converter.register_structure_hook(pd.DataFrame, _structure_dataframe_hook) diff --git a/baybe/serialization/mixin.py b/baybe/serialization/mixin.py index d423cb572..8f98fc63c 100644 --- a/baybe/serialization/mixin.py +++ b/baybe/serialization/mixin.py @@ -11,7 +11,7 @@ class SerialMixin: """A mixin class providing serialization functionality.""" - # Use slots so that the derived classes also remain slotted + # Use slots so that derived classes also remain slotted # See also: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes __slots__ = () diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index 6f891faba..c07ef9b13 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -16,6 +16,7 @@ UnstructuredValue, UnstructureHook, ) +from joblib.hashing import hash from baybe.exceptions import ModelNotTrainedError from baybe.objectives.base import Objective @@ -64,6 +65,13 @@ class _NoTransform(Enum): class SurrogateProtocol(Protocol): """Type protocol specifying the interface surrogate models need to implement.""" + # Use slots so that derived classes also remain slotted + # See also: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes + __slots__ = () + + # TODO: Final layout still to be optimized. For example, shall we require a + # `posterior` method? 
+ def fit( self, searchspace: SearchSpace, @@ -97,8 +105,20 @@ class Surrogate(ABC, SurrogateProtocol, SerialMixin): _searchspace: SearchSpace | None = field(init=False, default=None, eq=False) """The search space on which the surrogate operates. Available after fitting.""" + _objective: Objective | None = field(init=False, default=None, eq=False) + """The objective for which the surrogate was trained. Available after fitting.""" + + _measurements_hash: str = field(init=False, default=None, eq=False) + """The hash of the data the surrogate was trained on.""" + + _input_scaler: ColumnTransformer | None = field(init=False, default=None, eq=False) + """Scaler for transforming input values. Available after fitting. + + Scales a tensor containing parameter configurations in computational representation + to make them digestible for the model-specific, scale-agnostic posterior logic.""" + # TODO: type should be - # `botorch.models.transforms.outcome.Standardize | _NoTransform` + # `botorch.models.transforms.outcome.Standardize | _NoTransform` | None # but is currently omitted due to: # https://github.com/python-attrs/cattrs/issues/531 _output_scaler = field(init=False, default=None, eq=False) @@ -215,6 +235,10 @@ def _posterior_comp(self, candidates_comp: Tensor, /) -> Posterior: The same :class:`botorch.posteriors.Posterior` object as returned via :meth:`baybe.surrogates.base.Surrogate.posterior`. 
""" + # FIXME[typing]: It seems there is currently no better way to inform the type + # checker that the attribute is available at the time of the function call + assert self._input_scaler is not None + p = self._posterior(self._input_scaler.transform(candidates_comp)) if self._output_scaler is not _IDENTITY_TRANSFORM: p = self._output_scaler.untransform_posterior(p) @@ -275,6 +299,14 @@ def fit( """ # TODO: consider adding a validation step for `measurements` + # When the context is unchanged, no retraining is necessary + if ( + searchspace == self._searchspace + and objective == self._objective + and hash(measurements) == self._measurements_hash + ): + return + # Check if transfer learning capabilities are needed if (searchspace.n_tasks > 1) and (not self.supports_transfer_learning): raise ValueError( @@ -289,8 +321,10 @@ def fit( "Continuous search spaces are currently only supported by GPs." ) - # Remember on which search space the model is trained + # Remember the training context self._searchspace = searchspace + self._objective = objective + self._measurements_hash = hash(measurements) # Create context-specific transformations self._input_scaler = self._make_input_scaler(searchspace) diff --git a/mypy.ini b/mypy.ini index dfe237e75..7a9737aea 100644 --- a/mypy.ini +++ b/mypy.ini @@ -21,7 +21,7 @@ ignore_missing_imports = True [mypy-gpytorch.*] ignore_missing_imports = True -[mypy-joblib] +[mypy-joblib.*] ignore_missing_imports = True [mypy-mordred] diff --git a/pyproject.toml b/pyproject.toml index e681ebe8e..fe4f39793 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,11 +32,12 @@ dependencies = [ "botorch>=0.9.3,<1", "cattrs>=23.2.0", "exceptiongroup", - "funcy>=1.17", + "funcy>=1.17,<2", "gpytorch>=1.9.1,<2", + "joblib>1.4.0,<2", "ngboost>=0.3.12,<1", "numpy>=1.24.1,<2", - "pandas>=1.4.2", + "pandas>=1.4.2,<3", "protobuf<=3.20.3,<4", "scikit-learn>=1.1.1,<2", "scikit-learn-extra>=0.3.0,<1", @@ -46,10 +47,10 @@ dependencies = [ 
"typing_extensions>=4.7.0", # Telemetry: - "opentelemetry-sdk>=1.16.0", - "opentelemetry-propagator-aws-xray>=1.0.0", - "opentelemetry-exporter-otlp>=1.16.0", - "opentelemetry-sdk-extension-aws>=2.0.0", + "opentelemetry-sdk>=1.16.0,<2", + "opentelemetry-propagator-aws-xray>=1.0.0,<2", + "opentelemetry-exporter-otlp>=1.16.0,<2", + "opentelemetry-sdk-extension-aws>=2.0.0,<3", ] [project.urls] diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index fc2e83a6c..cf53c7625 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -117,3 +117,10 @@ def test_deprecated_surrogate_registration(): with pytest.raises(DeprecationError): register_custom_architecture() + + +def test_deprecated_surrogate_access(): + """Public attribute access to the surrogate model raises a warning.""" + recommender = BotorchRecommender() + with pytest.warns(DeprecationWarning): + recommender.surrogate_model diff --git a/tests/test_surrogate.py b/tests/test_surrogate.py new file mode 100644 index 000000000..5f6869bc7 --- /dev/null +++ b/tests/test_surrogate.py @@ -0,0 +1,26 @@ +"""Surrogate tests.""" + +from unittest.mock import patch + +from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender +from baybe.surrogates.gaussian_process.core import GaussianProcessSurrogate +from baybe.utils.dataframe import add_fake_results + + +@patch.object(GaussianProcessSurrogate, "_fit") +def test_caching(patched, searchspace, objective): + """A second fit call with the same context does not trigger retraining.""" + # Prepare the setting + measurements = RandomRecommender().recommend(3, searchspace, objective) + add_fake_results(measurements, objective.targets) + surrogate = GaussianProcessSurrogate() + + # First call + surrogate.fit(searchspace, objective, measurements) + patched.assert_called() + + patched.reset_mock() + + # Second call + surrogate.fit(searchspace, objective, measurements) + patched.assert_not_called()