Skip to content

Commit

Permalink
Replace deprecated botorch method to remove warning
Browse files Browse the repository at this point in the history
  • Loading branch information
nzw0301 committed Sep 24, 2023
1 parent a920a91 commit 7c4fd0a
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions optuna/terminator/improvement/gp/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

import numpy as np
from packaging import version

from optuna._imports import try_import
from optuna.distributions import _is_distribution_log
Expand All @@ -16,15 +17,20 @@


with try_import() as _imports:
from botorch.fit import fit_gpytorch_model
import botorch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize
from botorch.models.transforms import Standardize
import gpytorch
import torch

if version.parse(botorch.version.version) < version.parse("0.8.0"):
from botorch.fit import fit_gpytorch_model as fit_gpytorch_mll
else:
from botorch.fit import fit_gpytorch_mll

__all__ = [
"fit_gpytorch_model",
"fit_gpytorch_mll",
"SingleTaskGP",
"Normalize",
"Standardize",
Expand Down Expand Up @@ -61,7 +67,7 @@ def fit(

mll = gpytorch.mlls.ExactMarginalLogLikelihood(self._gp.likelihood, self._gp)

fit_gpytorch_model(mll)
fit_gpytorch_mll(mll)

def predict_mean_std(
self,
Expand Down

0 comments on commit 7c4fd0a

Please sign in to comment.