Commit

Merge branch 'main' of github.com:optuna/optuna-integration into tests/remove-stale-integration-tests
HideakiImamura committed Nov 17, 2023
2 parents 014b60a + bf16b34 commit f87c5de
Showing 7 changed files with 240 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -21,6 +21,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
* [Catalyst](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catalyst) ([example](https://github.com/optuna/optuna-examples/blob/main/pytorch/catalyst_simple.py))
* [Chainer](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainer) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainer_integration.py))
* [ChainerMN](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainermn) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainermn_simple.py))
* FastAI ([V1](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv1) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv1_simple.py)), [V2](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv2) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv2_simple.py)))
* [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras))
* [MXNet](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#mxnet) ([example](https://github.com/optuna/optuna-examples/tree/main/mxnet))
* [SHAP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#shap)
11 changes: 11 additions & 0 deletions docs/source/reference/index.rst
@@ -46,6 +46,17 @@ Chainer
   optuna.integration.ChainerPruningExtension
   optuna.integration.ChainerMNStudy

fast.ai
-------

.. autosummary::
   :toctree: generated/
   :nosignatures:

   optuna.integration.FastAIV1PruningCallback
   optuna.integration.FastAIV2PruningCallback
   optuna.integration.FastAIPruningCallback

Keras
-----

82 changes: 82 additions & 0 deletions optuna_integration/fastaiv1.py
@@ -0,0 +1,82 @@
from typing import Any

import optuna
from optuna._deprecated import deprecated_class
from packaging import version

from optuna_integration._imports import try_import


with try_import() as _imports:
    import fastai

    if version.parse(fastai.__version__) >= version.parse("2.0.0"):
        raise ImportError(
            f"You don't have fastai V1 installed! Fastai version: {fastai.__version__}"
        )

    from fastai.basic_train import Learner
    from fastai.callbacks import TrackerCallback

if not _imports.is_successful():
    TrackerCallback = object  # NOQA


@deprecated_class("2.4.0", "4.0.0")
class FastAIV1PruningCallback(TrackerCallback):
    """FastAI callback to prune unpromising trials for fastai.

    .. note::
        This callback is for fastai<2.0.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    fastai/fastaiv1_simple.py>`__
    if you want to add a pruning callback which monitors validation loss of a ``Learner``.

    Example:

        Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``.

        .. code::

            learn.fit(n_epochs, callbacks=[FastAIPruningCallback(learn, trial, "valid_loss")])
            learn.fit_one_cycle(
                n_epochs,
                cyc_len,
                max_lr,
                callbacks=[FastAIPruningCallback(learn, trial, "valid_loss")],
            )

    Args:
        learn:
            `fastai.basic_train.Learner <https://fastai1.fast.ai/basic_train.html#Learner>`_.
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current
            evaluation of the objective function.
        monitor:
            An evaluation metric for pruning, e.g. ``valid_loss`` and ``Accuracy``.
            Please refer to `fastai.callbacks.TrackerCallback reference
            <https://fastai1.fast.ai/callbacks.tracker.html#TrackerCallback>`_ for further
            details.
    """

    def __init__(self, learn: "Learner", trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__(learn, monitor)

        _imports.check()

        self._trial = trial

    def on_epoch_end(self, epoch: int, **kwargs: Any) -> None:
        value = self.get_monitor_value()
        if value is None:
            return

        # This conversion is necessary to avoid problems reported in issues.
        # - https://github.com/optuna/optuna/issues/642
        # - https://github.com/optuna/optuna/issues/655.
        self._trial.report(float(value), step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
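
For orientation, here is a minimal sketch of how this callback might be driven from an Optuna objective. It is not part of this commit: build_learner is a hypothetical helper standing in for whatever constructs the fastai v1 Learner, and the use of Learner.validate() for the return value is an assumption; only the callback registration mirrors the docstring above.

import optuna

from optuna_integration.fastaiv1 import FastAIV1PruningCallback


def objective(trial: optuna.trial.Trial) -> float:
    # build_learner is a hypothetical helper that constructs a
    # fastai.basic_train.Learner, e.g. as in the linked fastaiv1_simple.py example.
    learn = build_learner(trial)
    # Report "valid_loss" after every epoch and prune unpromising trials.
    learn.fit(
        10,
        callbacks=[FastAIV1PruningCallback(learn, trial, monitor="valid_loss")],
    )
    # Assumption: fastai v1's Learner.validate() returns [valid_loss, *metrics].
    return float(learn.validate()[0])


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)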
80 changes: 80 additions & 0 deletions optuna_integration/fastaiv2.py
@@ -0,0 +1,80 @@
import optuna
from packaging import version

from optuna_integration._imports import try_import


with try_import() as _imports:
    import fastai

    if version.parse(fastai.__version__) < version.parse("2.0.0"):
        raise ImportError(
            f"You don't have fastai V2 installed! Fastai version: {fastai.__version__}"
        )

    from fastai.callback.core import CancelFitException
    from fastai.callback.tracker import TrackerCallback

if not _imports.is_successful():
    TrackerCallback = object  # NOQA


class FastAIV2PruningCallback(TrackerCallback):
    """FastAI callback to prune unpromising trials for fastai.

    .. note::
        This callback is for fastai>=2.0.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    fastai/fastaiv2_simple.py>`__
    if you want to add a pruning callback which monitors validation loss of a ``Learner``.

    Example:

        Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``.

        .. code::

            learn = cnn_learner(dls, resnet18, metrics=[error_rate])
            learn.fit(n_epochs, cbs=[FastAIPruningCallback(trial)])  # Monitor "valid_loss"
            learn.fit_one_cycle(
                n_epochs,
                lr_max,
                cbs=[FastAIPruningCallback(trial, monitor="error_rate")],  # Monitor "error_rate"
            )

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current
            evaluation of the objective function.
        monitor:
            An evaluation metric for pruning, e.g. ``valid_loss`` or ``accuracy``.
            Please refer to `fastai.callback.TrackerCallback reference
            <https://docs.fast.ai/callback.tracker#TrackerCallback>`_ for further
            details.
    """

    # Implementation note: subclassing TrackerCallback lets this callback inherit its
    # scheduling behavior, e.g. running after the Recorder callback and being skipped
    # during lr_find.

    def __init__(self, trial: optuna.Trial, monitor: str = "valid_loss"):
        super().__init__(monitor=monitor)
        _imports.check()
        self.trial = trial

    def after_epoch(self) -> None:
        super().after_epoch()
        # self.idx is set by TrackerCallback.
        self.trial.report(self.recorder.final_record[self.idx], step=self.epoch)

        if self.trial.should_prune():
            raise CancelFitException()

    def after_fit(self) -> None:
        super().after_fit()
        if self.trial.should_prune():
            raise optuna.TrialPruned(f"Trial was pruned at epoch {self.epoch}.")


FastAIPruningCallback = FastAIV2PruningCallback
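
As with the V1 module, a minimal sketch of an objective using the V2 callback follows. It is not part of this commit: build_dls is a hypothetical helper returning fastai DataLoaders, and the Learner.validate() return shape is an assumption; the callback wiring itself follows the module above and the test below.

import optuna
import torch.nn as nn
import torch.nn.functional as F
from fastai.learner import Learner

from optuna_integration.fastaiv2 import FastAIV2PruningCallback


def objective(trial: optuna.trial.Trial) -> float:
    dls = build_dls()  # hypothetical helper returning fastai DataLoaders
    model = nn.Sequential(nn.Linear(20, 2), nn.LogSoftmax(dim=1))
    learn = Learner(dls, model, loss_func=F.nll_loss)
    # The callback monitors "valid_loss" by default; when the trial should be
    # pruned it raises CancelFitException during fit and TrialPruned afterwards.
    learn.fit(3, cbs=[FastAIV2PruningCallback(trial)])
    # Assumption: fastai v2's Learner.validate() returns [valid_loss, *metrics].
    return float(learn.validate()[0])


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)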
1 change: 1 addition & 0 deletions pyproject.toml
@@ -53,6 +53,7 @@ document = [
]
all = [
"catalyst",
"fastai",
"mxnet",
"shap",
"skorch",
7 changes: 7 additions & 0 deletions setup.cfg
@@ -26,3 +26,10 @@ no_implicit_reexport = True

ignore_missing_imports = True
exclude = venv|build|docs

[mypy-optuna_integration.chainer.*]
ignore_errors = True

[mypy-tests.test_chainer.*]
ignore_errors = True

58 changes: 58 additions & 0 deletions tests/test_fastaiv2.py
@@ -0,0 +1,58 @@
from typing import Any

import optuna
from optuna.testing.pruners import DeterministicPruner
import pytest

from optuna_integration._imports import try_import
from optuna_integration.fastaiv2 import FastAIV2PruningCallback


with try_import():
    from fastai.data.core import DataLoader
    from fastai.data.core import DataLoaders
    from fastai.learner import Learner
    from fastai.metrics import accuracy
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


def _generate_dummy_dataset() -> "torch.utils.data.DataLoader":
    data = torch.zeros(3, 20, dtype=torch.float32)
    target = torch.zeros(3, dtype=torch.int64)
    dataset = torch.utils.data.TensorDataset(data, target)
    return DataLoader(dataset, batch_size=1)


@pytest.fixture(scope="session")
def tmpdir(tmpdir_factory: Any) -> Any:
    return tmpdir_factory.mktemp("fastai_integration_test")


def test_fastai_pruning_callback(tmpdir: Any) -> None:
    train_loader = _generate_dummy_dataset()
    test_loader = _generate_dummy_dataset()

    data = DataLoaders(train_loader, test_loader, path=tmpdir)

    def objective(trial: optuna.trial.Trial) -> float:
        model = nn.Sequential(nn.Linear(20, 1), nn.Sigmoid())
        learn = Learner(
            data,
            model,
            loss_func=F.nll_loss,
            metrics=[accuracy],
        )
        learn.fit(1, cbs=FastAIV2PruningCallback(trial))

        return 1.0

    study = optuna.create_study(pruner=DeterministicPruner(True))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.PRUNED

    study = optuna.create_study(pruner=DeterministicPruner(False))
    study.optimize(objective, n_trials=1)
    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
    assert study.trials[0].value == 1.0
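
The test exercises both outcomes with DeterministicPruner; with a real pruner, an objective written the same way would be optimized along these lines (a sketch, not from this commit; the MedianPruner settings are illustrative):

import optuna

# Prune trials whose intermediate "valid_loss" is worse than the median of
# previous trials at the same step, skipping pruning for the first step.
pruner = optuna.pruners.MedianPruner(n_warmup_steps=1)
study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=10)  # an objective like the one in the test above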
