From b3fb4197625468a68ecbdc6d8dabef4e1c06f588 Mon Sep 17 00:00:00 2001 From: minami yoshiki Date: Sat, 11 Nov 2023 22:30:06 +0900 Subject: [PATCH 1/6] =?UTF-8?q?Fetched=20from=C2=A0https://github.com/optu?= =?UTF-8?q?na/optuna/commit/5e5c498ebf94b51c866ee3f74d479ec24b9155b4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- optuna_integration/fastaiv1.py | 82 ++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 optuna_integration/fastaiv1.py diff --git a/optuna_integration/fastaiv1.py b/optuna_integration/fastaiv1.py new file mode 100644 index 00000000..878f8991 --- /dev/null +++ b/optuna_integration/fastaiv1.py @@ -0,0 +1,82 @@ +from typing import Any + +from packaging import version + +import optuna +from optuna._deprecated import deprecated_class +from optuna._imports import try_import + + +with try_import() as _imports: + import fastai + + if version.parse(fastai.__version__) >= version.parse("2.0.0"): + raise ImportError( + f"You don't have fastai V1 installed! Fastai version: {fastai.__version__}" + ) + + from fastai.basic_train import Learner + from fastai.callbacks import TrackerCallback + +if not _imports.is_successful(): + TrackerCallback = object # NOQA + + +@deprecated_class("2.4.0", "4.0.0") +class FastAIV1PruningCallback(TrackerCallback): + """FastAI callback to prune unpromising trials for fastai. + + .. note:: + This callback is for fastai<2.0. + + See `the example `__ + if you want to add a pruning callback which monitors validation loss of a ``Learner``. + + Example: + + Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``. + + .. code:: + + learn.fit(n_epochs, callbacks=[FastAIPruningCallback(learn, trial, "valid_loss")]) + learn.fit_one_cycle( + n_epochs, + cyc_len, + max_lr, + callbacks=[FastAIPruningCallback(learn, trial, "valid_loss")], + ) + + + Args: + learn: + `fastai.basic_train.Learner `_. + trial: + A :class:`~optuna.trial.Trial` corresponding to the current + evaluation of the objective function. + monitor: + An evaluation metric for pruning, e.g. ``valid_loss`` and ``Accuracy``. + Please refer to `fastai.callbacks.TrackerCallback reference + `_ for further + details. + """ + + def __init__(self, learn: "Learner", trial: optuna.trial.Trial, monitor: str) -> None: + super().__init__(learn, monitor) + + _imports.check() + + self._trial = trial + + def on_epoch_end(self, epoch: int, **kwargs: Any) -> None: + value = self.get_monitor_value() + if value is None: + return + + # This conversion is necessary to avoid problems reported in issues. + # - https://github.com/optuna/optuna/issue/642 + # - https://github.com/optuna/optuna/issue/655. + self._trial.report(float(value), step=epoch) + if self._trial.should_prune(): + message = "Trial was pruned at epoch {}.".format(epoch) + raise optuna.TrialPruned(message) From a7266f6e7d44b55f31519db2d4ef509a1204a767 Mon Sep 17 00:00:00 2001 From: minami yoshiki Date: Sat, 11 Nov 2023 22:32:02 +0900 Subject: [PATCH 2/6] =?UTF-8?q?Fetched=20from=C2=A0https://github.com/optu?= =?UTF-8?q?na/optuna/commit/8ad265b680d8bd786bae703ad979b47929b8ef2d?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- optuna_integration/fastaiv2.py | 80 ++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 optuna_integration/fastaiv2.py diff --git a/optuna_integration/fastaiv2.py b/optuna_integration/fastaiv2.py new file mode 100644 index 00000000..f9ea5caa --- /dev/null +++ b/optuna_integration/fastaiv2.py @@ -0,0 +1,80 @@ +from packaging import version + +import optuna +from optuna._imports import try_import + + +with try_import() as _imports: + import fastai + + if version.parse(fastai.__version__) < version.parse("2.0.0"): + raise ImportError( + f"You don't have fastai V2 installed! Fastai version: {fastai.__version__}" + ) + + from fastai.callback.core import CancelFitException + from fastai.callback.tracker import TrackerCallback + +if not _imports.is_successful(): + TrackerCallback = object # NOQA + + +class FastAIV2PruningCallback(TrackerCallback): + """FastAI callback to prune unpromising trials for fastai. + + .. note:: + This callback is for fastai>=2.0. + + See `the example `__ + if you want to add a pruning callback which monitors validation loss of a ``Learner``. + + Example: + + Register a pruning callback to ``learn.fit`` and ``learn.fit_one_cycle``. + + .. code:: + + learn = cnn_learner(dls, resnet18, metrics=[error_rate]) + learn.fit(n_epochs, cbs=[FastAIPruningCallback(trial)]) # Monitor "valid_loss" + learn.fit_one_cycle( + n_epochs, + lr_max, + cbs=[FastAIPruningCallback(trial, monitor="error_rate")], # Monitor "error_rate" + ) + + + Args: + trial: + A :class:`~optuna.trial.Trial` corresponding to the current + evaluation of the objective function. + monitor: + An evaluation metric for pruning, e.g. ``valid_loss`` or ``accuracy``. + Please refer to `fastai.callback.TrackerCallback reference + `_ for further + details. + """ + + # Implementation notes: it's a subclass of TrackerCallback to benefit from it. For example, + # when to run (after the Recorder callback), when not to (like with lr_find), etc. + + def __init__(self, trial: optuna.Trial, monitor: str = "valid_loss"): + super().__init__(monitor=monitor) + _imports.check() + self.trial = trial + + def after_epoch(self) -> None: + super().after_epoch() + # self.idx is set by TrackTrackerCallback + self.trial.report(self.recorder.final_record[self.idx], step=self.epoch) + + if self.trial.should_prune(): + raise CancelFitException() + + def after_fit(self) -> None: + super().after_fit() + if self.trial.should_prune(): + raise optuna.TrialPruned(f"Trial was pruned at epoch {self.epoch}.") + + +FastAIPruningCallback = FastAIV2PruningCallback From a180384124ef97a89ed3b01e4b1a6a2c2ae55b73 Mon Sep 17 00:00:00 2001 From: minami yoshiki Date: Sat, 11 Nov 2023 22:33:46 +0900 Subject: [PATCH 3/6] =?UTF-8?q?Fetched=20from=C2=A0https://github.com/optu?= =?UTF-8?q?na/optuna/commit/03b08ebbb99562b1d9e403f22fdba8ff93d4b7e1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_fastaiv2.py | 60 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/test_fastaiv2.py diff --git a/tests/test_fastaiv2.py b/tests/test_fastaiv2.py new file mode 100644 index 00000000..b25ada79 --- /dev/null +++ b/tests/test_fastaiv2.py @@ -0,0 +1,60 @@ +from typing import Any + +import pytest + +import optuna +from optuna._imports import try_import +from optuna.integration import FastAIV2PruningCallback +from optuna.testing.pruners import DeterministicPruner + + +with try_import(): + from fastai.data.core import DataLoader + from fastai.data.core import DataLoaders + from fastai.learner import Learner + from fastai.metrics import accuracy + import torch + import torch.nn as nn + import torch.nn.functional as F + +pytestmark = pytest.mark.integration + + +def _generate_dummy_dataset() -> "torch.utils.data.DataLoader": + data = torch.zeros(3, 20, dtype=torch.float32) + target = torch.zeros(3, dtype=torch.int64) + dataset = torch.utils.data.TensorDataset(data, target) + return DataLoader(dataset, batch_size=1) + + +@pytest.fixture(scope="session") +def tmpdir(tmpdir_factory: Any) -> Any: + return tmpdir_factory.mktemp("fastai_integration_test") + + +def test_fastai_pruning_callback(tmpdir: Any) -> None: + train_loader = _generate_dummy_dataset() + test_loader = _generate_dummy_dataset() + + data = DataLoaders(train_loader, test_loader, path=tmpdir) + + def objective(trial: optuna.trial.Trial) -> float: + model = nn.Sequential(nn.Linear(20, 1), nn.Sigmoid()) + learn = Learner( + data, + model, + loss_func=F.nll_loss, + metrics=[accuracy], + ) + learn.fit(1, cbs=FastAIV2PruningCallback(trial)) + + return 1.0 + + study = optuna.create_study(pruner=DeterministicPruner(True)) + study.optimize(objective, n_trials=1) + assert study.trials[0].state == optuna.trial.TrialState.PRUNED + + study = optuna.create_study(pruner=DeterministicPruner(False)) + study.optimize(objective, n_trials=1) + assert study.trials[0].state == optuna.trial.TrialState.COMPLETE + assert study.trials[0].value == 1.0 From b182d46e8d0315ab3d604134c0064b792e11da6c Mon Sep 17 00:00:00 2001 From: c-bata Date: Tue, 14 Nov 2023 11:40:14 +0900 Subject: [PATCH 4/6] Exclude mypy checks for chainer --- setup.cfg | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/setup.cfg b/setup.cfg index 14512595..193254d4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,3 +26,10 @@ no_implicit_reexport = True ignore_missing_imports = True exclude = venv|build|docs + +[mypy-optuna_integration.chainer.*] +ignore_errors = True + +[mypy-tests.test_chainer.*] +ignore_errors = True + From 7cdde95dafb718236e7a753019712662f29925a6 Mon Sep 17 00:00:00 2001 From: minami yoshiki Date: Tue, 14 Nov 2023 23:10:47 +0900 Subject: [PATCH 5/6] On the Optuna Integration side, isolation of Optuna integration modules on the fast.ai module --- README.md | 1 + docs/source/reference/index.rst | 11 +++++++++++ optuna_integration/fastaiv1.py | 6 +++--- optuna_integration/fastaiv2.py | 4 ++-- pyproject.toml | 1 + tests/test_fastaiv2.py | 8 ++++---- 6 files changed, 22 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 42e8b9c8..f5bab8c3 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Optuna-Integration API reference is [here](https://optuna-integration.readthedoc * [Catalyst](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catalyst) ([example](https://github.com/optuna/optuna-examples/blob/main/pytorch/catalyst_simple.py)) * [Chainer](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainer) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainer_integration.py)) * [ChainerMN](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainermn) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainermn_simple.py)) +* FastAI ([V1](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv1) ([example](https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv1_simple.py)), ([V2](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#fastaiv2) ([example]https://github.com/optuna/optuna-examples/tree/main/fastai/fastaiv2_simple.py))) * [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras)) * [MXNet](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#mxnet) ([example](https://github.com/optuna/optuna-examples/tree/main/mxnet)) * [SHAP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#shap) diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index cad08a07..cfbb740a 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -46,6 +46,17 @@ Chainer optuna.integration.ChainerPruningExtension optuna.integration.ChainerMNStudy +fast.ai +------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + optuna.integration.FastAIV1PruningCallback + optuna.integration.FastAIV2PruningCallback + optuna.integration.FastAIPruningCallback + Keras ----- diff --git a/optuna_integration/fastaiv1.py b/optuna_integration/fastaiv1.py index 878f8991..35729b1b 100644 --- a/optuna_integration/fastaiv1.py +++ b/optuna_integration/fastaiv1.py @@ -1,10 +1,10 @@ from typing import Any -from packaging import version - import optuna from optuna._deprecated import deprecated_class -from optuna._imports import try_import +from packaging import version + +from optuna_integration._imports import try_import with try_import() as _imports: diff --git a/optuna_integration/fastaiv2.py b/optuna_integration/fastaiv2.py index f9ea5caa..a5d5748c 100644 --- a/optuna_integration/fastaiv2.py +++ b/optuna_integration/fastaiv2.py @@ -1,7 +1,7 @@ +import optuna from packaging import version -import optuna -from optuna._imports import try_import +from optuna_integration._imports import try_import with try_import() as _imports: diff --git a/pyproject.toml b/pyproject.toml index ffb8d718..c854e068 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ all = [ "allennlp>=2.2.0; python_version<'3.11'", "catalyst", "chainer>=5.0.0", + "fastai", "mpi4py", "mxnet", "shap", diff --git a/tests/test_fastaiv2.py b/tests/test_fastaiv2.py index b25ada79..05718ad9 100644 --- a/tests/test_fastaiv2.py +++ b/tests/test_fastaiv2.py @@ -1,11 +1,11 @@ from typing import Any -import pytest - import optuna -from optuna._imports import try_import -from optuna.integration import FastAIV2PruningCallback from optuna.testing.pruners import DeterministicPruner +import pytest + +from optuna_integration.fastaiv2 import FastAIV2PruningCallback +from optuna_integration._imports import try_import with try_import(): From a67310aee763e32905107f83434aa6ad207689eb Mon Sep 17 00:00:00 2001 From: minami yoshiki Date: Wed, 15 Nov 2023 22:45:10 +0900 Subject: [PATCH 6/6] I followed the review comments. --- tests/test_fastaiv2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_fastaiv2.py b/tests/test_fastaiv2.py index 05718ad9..f2f1acf9 100644 --- a/tests/test_fastaiv2.py +++ b/tests/test_fastaiv2.py @@ -4,8 +4,8 @@ from optuna.testing.pruners import DeterministicPruner import pytest -from optuna_integration.fastaiv2 import FastAIV2PruningCallback from optuna_integration._imports import try_import +from optuna_integration.fastaiv2 import FastAIV2PruningCallback with try_import(): @@ -17,8 +17,6 @@ import torch.nn as nn import torch.nn.functional as F -pytestmark = pytest.mark.integration - def _generate_dummy_dataset() -> "torch.utils.data.DataLoader": data = torch.zeros(3, 20, dtype=torch.float32)