Commit
Merge branch 'main' of github.com:optuna/optuna-integration into tests/remove-stale-integration-tests
Showing 7 changed files with 240 additions and 0 deletions.
@@ -0,0 +1,82 @@
from typing import Any

import optuna
from optuna._deprecated import deprecated_class
from packaging import version

from optuna_integration._imports import try_import


with try_import() as _imports:
    import fastai

    if version.parse(fastai.__version__) >= version.parse("2.0.0"):
        raise ImportError(
            f"You don't have fastai V1 installed! Fastai version: {fastai.__version__}"
        )

    from fastai.basic_train import Learner
    from fastai.callbacks import TrackerCallback

if not _imports.is_successful():
    TrackerCallback = object  # NOQA


@deprecated_class("2.4.0", "4.0.0")
class FastAIV1PruningCallback(TrackerCallback):
    """FastAI callback to prune unpromising trials for fastai.

    .. note::
        This callback is for fastai<2.0.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    fastai/fastaiv1_simple.py>`__
    if you want to add a pruning callback which monitors validation loss of a ``Learner``.

    Example:

        Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``.

        .. code::

            learn.fit(n_epochs, callbacks=[FastAIPruningCallback(learn, trial, "valid_loss")])
            learn.fit_one_cycle(
                n_epochs,
                cyc_len,
                max_lr,
                callbacks=[FastAIPruningCallback(learn, trial, "valid_loss")],
            )

    Args:
        learn:
            `fastai.basic_train.Learner <https://fastai1.fast.ai/basic_train.html#Learner>`_.
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current
            evaluation of the objective function.
        monitor:
            An evaluation metric for pruning, e.g. ``valid_loss`` and ``Accuracy``.
            Please refer to `fastai.callbacks.TrackerCallback reference
            <https://fastai1.fast.ai/callbacks.tracker.html#TrackerCallback>`_ for further
            details.
    """

    def __init__(self, learn: "Learner", trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__(learn, monitor)

        _imports.check()

        self._trial = trial

    def on_epoch_end(self, epoch: int, **kwargs: Any) -> None:
        value = self.get_monitor_value()
        if value is None:
            return

        # This conversion is necessary to avoid problems reported in issues.
        # - https://github.com/optuna/optuna/issues/642
        # - https://github.com/optuna/optuna/issues/655.
        self._trial.report(float(value), step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
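For context, a minimal sketch of how this deprecated V1 callback would be wired into an objective. Everything here is hypothetical rather than part of the commit: it assumes fastai<2.0, a prepared ``DataBunch`` named ``data``, a ``model`` defined elsewhere, and the ``optuna_integration.fastaiv1`` module path inferred from the diff.

    # Hypothetical usage sketch, not part of this commit.
    import optuna
    from fastai.basic_train import Learner  # fastai<2.0 assumed
    from optuna_integration.fastaiv1 import FastAIV1PruningCallback  # assumed module path


    def objective(trial: optuna.trial.Trial) -> float:
        learn = Learner(data, model)  # `data` (DataBunch) and `model` defined elsewhere
        learn.fit(10, callbacks=[FastAIV1PruningCallback(learn, trial, "valid_loss")])
        # Learner.validate() returns [valid_loss, *metrics] in fastai v1.
        return float(learn.validate()[0])


    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)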
@@ -0,0 +1,80 @@
import optuna
from packaging import version

from optuna_integration._imports import try_import


with try_import() as _imports:
    import fastai

    if version.parse(fastai.__version__) < version.parse("2.0.0"):
        raise ImportError(
            f"You don't have fastai V2 installed! Fastai version: {fastai.__version__}"
        )

    from fastai.callback.core import CancelFitException
    from fastai.callback.tracker import TrackerCallback

if not _imports.is_successful():
    TrackerCallback = object  # NOQA


class FastAIV2PruningCallback(TrackerCallback):
    """FastAI callback to prune unpromising trials for fastai.

    .. note::
        This callback is for fastai>=2.0.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    fastai/fastaiv2_simple.py>`__
    if you want to add a pruning callback which monitors validation loss of a ``Learner``.

    Example:

        Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``.

        .. code::

            learn = cnn_learner(dls, resnet18, metrics=[error_rate])
            learn.fit(n_epochs, cbs=[FastAIPruningCallback(trial)])  # Monitor "valid_loss"
            learn.fit_one_cycle(
                n_epochs,
                lr_max,
                cbs=[FastAIPruningCallback(trial, monitor="error_rate")],  # Monitor "error_rate"
            )

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current
            evaluation of the objective function.
        monitor:
            An evaluation metric for pruning, e.g. ``valid_loss`` or ``accuracy``.
            Please refer to `fastai.callback.TrackerCallback reference
            <https://docs.fast.ai/callback.tracker#TrackerCallback>`_ for further
            details.
    """

    # Implementation notes: it's a subclass of TrackerCallback to benefit from it. For example,
    # when to run (after the Recorder callback), when not to (like with lr_find), etc.

    def __init__(self, trial: optuna.Trial, monitor: str = "valid_loss"):
        super().__init__(monitor=monitor)
        _imports.check()
        self.trial = trial

    def after_epoch(self) -> None:
        super().after_epoch()
        # self.idx is set by TrackerCallback.
        self.trial.report(self.recorder.final_record[self.idx], step=self.epoch)

        if self.trial.should_prune():
            raise CancelFitException()

    def after_fit(self) -> None:
        super().after_fit()
        if self.trial.should_prune():
            raise optuna.TrialPruned(f"Trial was pruned at epoch {self.epoch}.")


FastAIPruningCallback = FastAIV2PruningCallback
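For comparison, a usage sketch for the V2 path via the ``FastAIPruningCallback`` alias. Again hypothetical: ``dls`` and ``model`` are placeholders not defined in this diff, and the return value assumes fastai v2's ``Learner.validate()`` yields ``[valid_loss, *metrics]``.

    # Hypothetical usage sketch, not part of this commit.
    import optuna
    from fastai.learner import Learner  # fastai>=2.0 assumed
    from optuna_integration.fastaiv2 import FastAIPruningCallback


    def objective(trial: optuna.trial.Trial) -> float:
        lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
        learn = Learner(dls, model)  # `dls` (DataLoaders) and `model` defined elsewhere
        # Inside fit, after_epoch reports to the trial and raises
        # CancelFitException when the pruner asks to stop; after_fit then
        # surfaces optuna.TrialPruned to the study.
        learn.fit_one_cycle(5, lr, cbs=[FastAIPruningCallback(trial)])
        return float(learn.validate()[0])


    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)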
@@ -53,6 +53,7 @@ document = [
]
all = [
    "catalyst",
    "fastai",
    "mxnet",
    "shap",
    "skorch",
@@ -0,0 +1,58 @@
from typing import Any

import optuna
from optuna.testing.pruners import DeterministicPruner
import pytest

from optuna_integration._imports import try_import
from optuna_integration.fastaiv2 import FastAIV2PruningCallback


with try_import():
    from fastai.data.core import DataLoader
    from fastai.data.core import DataLoaders
    from fastai.learner import Learner
    from fastai.metrics import accuracy
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


def _generate_dummy_dataset() -> "DataLoader":
    data = torch.zeros(3, 20, dtype=torch.float32)
    target = torch.zeros(3, dtype=torch.int64)
    dataset = torch.utils.data.TensorDataset(data, target)
    return DataLoader(dataset, batch_size=1)


@pytest.fixture(scope="session")
def tmpdir(tmpdir_factory: Any) -> Any:
    return tmpdir_factory.mktemp("fastai_integration_test")


def test_fastai_pruning_callback(tmpdir: Any) -> None:
    train_loader = _generate_dummy_dataset()
    test_loader = _generate_dummy_dataset()

    data = DataLoaders(train_loader, test_loader, path=tmpdir)

    def objective(trial: optuna.trial.Trial) -> float:
        model = nn.Sequential(nn.Linear(20, 1), nn.Sigmoid())
        learn = Learner(
            data,
            model,
            loss_func=F.nll_loss,
            metrics=[accuracy],
        )
        learn.fit(1, cbs=FastAIV2PruningCallback(trial))

        return 1.0

    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0
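The test hinges on ``optuna.testing.pruners.DeterministicPruner``, which answers ``should_prune`` with a fixed boolean so both the pruned and completed paths can be asserted. Conceptually it behaves like the sketch below, a hypothetical re-implementation on top of ``optuna.pruners.BasePruner`` rather than the actual source:

    # Hypothetical stand-in for DeterministicPruner, for illustration only.
    import optuna
    from optuna.pruners import BasePruner


    class FixedAnswerPruner(BasePruner):
        def __init__(self, is_pruning: bool) -> None:
            # The constructor argument decides every pruning query up front.
            self._is_pruning = is_pruning

        def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
            # Ignore reported intermediate values; always return the fixed answer.
            return self._is_pruning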