diff --git a/README.md b/README.md
index 42e8b9c8..f5bab8c3 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc
 * [Catalyst](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catalyst) ([example](https://github.com/optuna/optuna-examples/blob/main/pytorch/catalyst_simple.py))
 * [Chainer](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainer) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainer_integration.py))
 * [ChainerMN](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainermn) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainermn_simple.py))
+* FastAI ([V1](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv1) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv1_simple.py)), [V2](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv2) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv2_simple.py)))
 * [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras))
 * [MXNet](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#mxnet) ([example](https://github.com/optuna/optuna-examples/tree/main/mxnet))
 * [SHAP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#shap)
diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst
index cad08a07..cfbb740a 100644
--- a/docs/source/reference/index.rst
+++ b/docs/source/reference/index.rst
@@ -46,6 +46,17 @@ Chainer
    optuna.integration.ChainerPruningExtension
    optuna.integration.ChainerMNStudy
 
+fast.ai
+-------
+
+.. autosummary::
+   :toctree: generated/
+   :nosignatures:
+
+   optuna.integration.FastAIV1PruningCallback
+   optuna.integration.FastAIV2PruningCallback
+   optuna.integration.FastAIPruningCallback
+
 Keras
 -----
 
diff --git a/optuna_integration/fastaiv1.py b/optuna_integration/fastaiv1.py
new file mode 100644
index 00000000..35729b1b
--- /dev/null
+++ b/optuna_integration/fastaiv1.py
@@ -0,0 +1,82 @@
+from typing import Any
+
+import optuna
+from optuna._deprecated import deprecated_class
+from packaging import version
+
+from optuna_integration._imports import try_import
+
+
+with try_import() as _imports:
+    import fastai
+
+    if version.parse(fastai.__version__) >= version.parse("2.0.0"):
+        raise ImportError(
+            f"You don't have fastai V1 installed! Fastai version: {fastai.__version__}"
+        )
+
+    from fastai.basic_train import Learner
+    from fastai.callbacks import TrackerCallback
+
+if not _imports.is_successful():
+    TrackerCallback = object  # NOQA
+
+
+@deprecated_class("2.4.0", "4.0.0")
+class FastAIV1PruningCallback(TrackerCallback):
+    """FastAI callback to prune unpromising trials.
+
+    .. note::
+        This callback is for fastai<2.0.
+
+    See `the example <https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv1_simple.py>`__
+    if you want to add a pruning callback that monitors the validation loss of a ``Learner``.
+
+    Example:
+
+        Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``.
+
+        .. code::
+
+            learn.fit(n_epochs, callbacks=[FastAIV1PruningCallback(learn, trial, "valid_loss")])
+            learn.fit_one_cycle(
+                cyc_len,
+                max_lr,
+                callbacks=[FastAIV1PruningCallback(learn, trial, "valid_loss")],
+            )
+
+
+    Args:
+        learn:
+            `fastai.basic_train.Learner <https://fastai1.fast.ai/basic_train.html#Learner>`_.
+        trial:
+            A :class:`~optuna.trial.Trial` corresponding to the current
+            evaluation of the objective function.
+        monitor:
+            An evaluation metric for pruning, e.g. ``valid_loss`` or ``Accuracy``.
+            Please refer to the `fastai.callbacks.TrackerCallback reference
+            <https://fastai1.fast.ai/callbacks.tracker.html#TrackerCallback>`_ for further
+            details.
+    """
+
+    def __init__(self, learn: "Learner", trial: optuna.trial.Trial, monitor: str) -> None:
+        super().__init__(learn, monitor)
+
+        _imports.check()
+
+        self._trial = trial
+
+    def on_epoch_end(self, epoch: int, **kwargs: Any) -> None:
+        value = self.get_monitor_value()
+        if value is None:
+            return
+
+        # This conversion is necessary to avoid problems reported in the following issues:
+        # - https://github.com/optuna/optuna/issues/642
+        # - https://github.com/optuna/optuna/issues/655
+        self._trial.report(float(value), step=epoch)
+        if self._trial.should_prune():
+            message = "Trial was pruned at epoch {}.".format(epoch)
+            raise optuna.TrialPruned(message)
diff --git a/optuna_integration/fastaiv2.py b/optuna_integration/fastaiv2.py
new file mode 100644
index 00000000..a5d5748c
--- /dev/null
+++ b/optuna_integration/fastaiv2.py
@@ -0,0 +1,80 @@
+import optuna
+from packaging import version
+
+from optuna_integration._imports import try_import
+
+
+with try_import() as _imports:
+    import fastai
+
+    if version.parse(fastai.__version__) < version.parse("2.0.0"):
+        raise ImportError(
+            f"You don't have fastai V2 installed! Fastai version: {fastai.__version__}"
+        )
+
+    from fastai.callback.core import CancelFitException
+    from fastai.callback.tracker import TrackerCallback
+
+if not _imports.is_successful():
+    TrackerCallback = object  # NOQA
+
+
+class FastAIV2PruningCallback(TrackerCallback):
+    """FastAI callback to prune unpromising trials.
+
+    .. note::
+        This callback is for fastai>=2.0.
+
+    See `the example <https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv2_simple.py>`__
+    if you want to add a pruning callback that monitors the validation loss of a ``Learner``.
+
+    Example:
+
+        Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``.
+
+        .. code::
+
+            learn = cnn_learner(dls, resnet18, metrics=[error_rate])
+            learn.fit(n_epochs, cbs=[FastAIPruningCallback(trial)])  # Monitors "valid_loss".
+            learn.fit_one_cycle(
+                n_epochs,
+                lr_max,
+                cbs=[FastAIPruningCallback(trial, monitor="error_rate")],  # Monitors "error_rate".
+            )
+
+
+    Args:
+        trial:
+            A :class:`~optuna.trial.Trial` corresponding to the current
+            evaluation of the objective function.
+        monitor:
+            An evaluation metric for pruning, e.g. ``valid_loss`` or ``accuracy``.
+            Please refer to the `fastai.callback.TrackerCallback reference
+            <https://docs.fast.ai/callback.tracker#TrackerCallback>`_ for further
+            details.
+    """
+
+    # Implementation note: subclassing TrackerCallback lets this callback inherit its
+    # scheduling behavior, e.g. running after the Recorder callback and being skipped
+    # by utilities such as lr_find.
+
+    def __init__(self, trial: optuna.Trial, monitor: str = "valid_loss"):
+        super().__init__(monitor=monitor)
+        _imports.check()
+        self.trial = trial
+
+    def after_epoch(self) -> None:
+        super().after_epoch()
+        # self.idx is set by TrackerCallback.
+        self.trial.report(self.recorder.final_record[self.idx], step=self.epoch)
+
+        if self.trial.should_prune():
+            raise CancelFitException()
+
+    def after_fit(self) -> None:
+        super().after_fit()
+        if self.trial.should_prune():
+            raise optuna.TrialPruned(f"Trial was pruned at epoch {self.epoch}.")
+
+
+FastAIPruningCallback = FastAIV2PruningCallback
diff --git a/pyproject.toml b/pyproject.toml
index 5e2d4b14..2e082a12 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,6 +53,7 @@ document = [
 ]
 all = [
     "catalyst",
+    "fastai",
     "mxnet",
     "shap",
     "skorch",
diff --git a/setup.cfg b/setup.cfg
index 14512595..193254d4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,3 +26,10 @@ no_implicit_reexport = True
 ignore_missing_imports = True
 
 exclude = venv|build|docs
+
+[mypy-optuna_integration.chainer.*]
+ignore_errors = True
+
+[mypy-tests.test_chainer.*]
+ignore_errors = True
+
diff --git a/tests/test_fastaiv2.py b/tests/test_fastaiv2.py
new file mode 100644
index 00000000..f2f1acf9
--- /dev/null
+++ b/tests/test_fastaiv2.py
@@ -0,0 +1,58 @@
+from typing import Any
+
+import optuna
+from optuna.testing.pruners import DeterministicPruner
+import pytest
+
+from optuna_integration._imports import try_import
+from optuna_integration.fastaiv2 import FastAIV2PruningCallback
+
+
+with try_import():
+    from fastai.data.core import DataLoader
+    from fastai.data.core import DataLoaders
+    from fastai.learner import Learner
+    from fastai.metrics import accuracy
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+
+
+def _generate_dummy_dataset() -> "DataLoader":
+    data = torch.zeros(3, 20, dtype=torch.float32)
+    target = torch.zeros(3, dtype=torch.int64)
+    dataset = torch.utils.data.TensorDataset(data, target)
+    return DataLoader(dataset, batch_size=1)
+
+
+@pytest.fixture(scope="session")
+def tmpdir(tmpdir_factory: Any) -> Any:
+    return tmpdir_factory.mktemp("fastai_integration_test")
+
+
+def test_fastai_pruning_callback(tmpdir: Any) -> None:
+    train_loader = _generate_dummy_dataset()
+    test_loader = _generate_dummy_dataset()
+
+    data = DataLoaders(train_loader, test_loader, path=tmpdir)
+
+    def objective(trial: optuna.trial.Trial) -> float:
+        model = nn.Sequential(nn.Linear(20, 1), nn.Sigmoid())
+        learn = Learner(
+            data,
+            model,
+            loss_func=F.nll_loss,
+            metrics=[accuracy],
+        )
+        learn.fit(1, cbs=FastAIV2PruningCallback(trial))
+
+        return 1.0
+
+    study = optuna.create_study(pruner=DeterministicPruner(True))
+    study.optimize(objective, n_trials=1)
+    assert study.trials[0].state == optuna.trial.TrialState.PRUNED
+
+    study = optuna.create_study(pruner=DeterministicPruner(False))
+    study.optimize(objective, n_trials=1)
+    assert study.trials[0].state == optuna.trial.TrialState.COMPLETE
+    assert study.trials[0].value == 1.0
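
For reviewers who want to try the change end to end, the snippet below wires `FastAIV2PruningCallback` into a complete Optuna study. It is a minimal sketch, not part of this patch: the toy dataset mirrors the one in `tests/test_fastaiv2.py`, while the `n_units` search space, the three-epoch budget, and the `MedianPruner` choice are illustrative assumptions.

```python
# Minimal usage sketch (assumes fastai>=2.0 and this package installed as
# optuna_integration). Dataset, model, and search space are illustrative only.
import optuna
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.data.core import DataLoader, DataLoaders
from fastai.learner import Learner

from optuna_integration.fastaiv2 import FastAIV2PruningCallback


def _toy_loaders() -> DataLoaders:
    # Tiny random binary-classification set, mirroring the shapes in the test above.
    data = torch.randn(32, 20)
    target = torch.randint(0, 2, (32,))
    dataset = torch.utils.data.TensorDataset(data, target)
    return DataLoaders(DataLoader(dataset, batch_size=8), DataLoader(dataset, batch_size=8))


def objective(trial: optuna.trial.Trial) -> float:
    n_units = trial.suggest_int("n_units", 4, 64, log=True)  # hypothetical search space
    model = nn.Sequential(nn.Linear(20, n_units), nn.ReLU(), nn.Linear(n_units, 2))
    learn = Learner(_toy_loaders(), model, loss_func=F.cross_entropy)
    # The callback reports "valid_loss" after every epoch and cancels training
    # early if the pruner decides the trial is unpromising.
    learn.fit(3, cbs=[FastAIV2PruningCallback(trial)])
    return float(learn.validate()[0])  # final validation loss


study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
```

Note the two-stage pruning design in `fastaiv2.py`: `after_epoch` raises fastai's `CancelFitException` so training stops cleanly through fastai's own control flow, and `after_fit` then re-raises `optuna.TrialPruned` so the study records the trial as pruned.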