diff --git a/optuna/terminator/improvement/gp/botorch.py b/optuna/terminator/improvement/gp/botorch.py index 326cb452b4..0f43573fd3 100644 --- a/optuna/terminator/improvement/gp/botorch.py +++ b/optuna/terminator/improvement/gp/botorch.py @@ -3,6 +3,7 @@ from typing import Optional import numpy as np +from packaging import version from optuna._imports import try_import from optuna.distributions import _is_distribution_log @@ -16,15 +17,20 @@ with try_import() as _imports: - from botorch.fit import fit_gpytorch_model + import botorch from botorch.models import SingleTaskGP from botorch.models.transforms import Normalize from botorch.models.transforms import Standardize import gpytorch import torch + if version.parse(botorch.version.version) < version.parse("0.8.0"): + from botorch.fit import fit_gpytorch_model as fit_gpytorch_mll + else: + from botorch.fit import fit_gpytorch_mll + __all__ = [ - "fit_gpytorch_model", + "fit_gpytorch_mll", "SingleTaskGP", "Normalize", "Standardize", @@ -61,7 +67,7 @@ def fit( mll = gpytorch.mlls.ExactMarginalLogLikelihood(self._gp.likelihood, self._gp) - fit_gpytorch_model(mll) + fit_gpytorch_mll(mll) def predict_mean_std( self,