From ececba78bd315eccfa8e6a8dae7f382489c59562 Mon Sep 17 00:00:00 2001 From: Rodrigo de Salvo Braz Date: Mon, 4 Nov 2024 11:32:30 -0800 Subject: [PATCH] Introduce LearningLogger Summary: Introduce protocol LearningLogger for logging with `offline_learning`. Reviewed By: PavlosApo Differential Revision: D65361620 fbshipit-source-id: 9fe9c9a4752cb549185c163d37c282a90d55921a --- .../train_and_eval/learning_logger.py | 26 +++++++++++++++++++ .../offline_learning_and_evaluation.py | 8 ++++-- .../train_and_eval/online_learning.py | 1 + 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 pearl/utils/functional_utils/train_and_eval/learning_logger.py diff --git a/pearl/utils/functional_utils/train_and_eval/learning_logger.py b/pearl/utils/functional_utils/train_and_eval/learning_logger.py new file mode 100644 index 00000000..27018a78 --- /dev/null +++ b/pearl/utils/functional_utils/train_and_eval/learning_logger.py @@ -0,0 +1,26 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Any, Dict, Protocol + + +class LearningLogger(Protocol): + """Protocol for a learning logger. + A learning logger is a callable that takes in a dictionary of results and a step number. + It can be used to log the results of a learning process to a database or a file. + Args: + results: A dictionary of results. + step: The current step of the learning process. + prefix: A prefix to add to the logged results. + """ + + def __call__(self, results: Dict[str, Any], step: int, prefix: str = "") -> None: + pass + + +def null_learning_logger(results: Dict[str, str], step: int, prefix: str = "") -> None: + """ + A null learning logger that does nothing. + """ + pass diff --git a/pearl/utils/functional_utils/train_and_eval/offline_learning_and_evaluation.py b/pearl/utils/functional_utils/train_and_eval/offline_learning_and_evaluation.py index ecad09e2..bab8c600 100644 --- a/pearl/utils/functional_utils/train_and_eval/offline_learning_and_evaluation.py +++ b/pearl/utils/functional_utils/train_and_eval/offline_learning_and_evaluation.py @@ -21,6 +21,10 @@ from pearl.replay_buffers.transition import TransitionBatch from pearl.utils.functional_utils.experimentation.set_seed import set_seed from pearl.utils.functional_utils.requests_get import requests_get +from pearl.utils.functional_utils.train_and_eval.learning_logger import ( + LearningLogger, + null_learning_logger, +) from pearl.utils.functional_utils.train_and_eval.online_learning import run_episode from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace @@ -131,6 +135,7 @@ def offline_learning( offline_agent: PearlAgent, data_buffer: ReplayBuffer, training_epochs: int = 1000, + logger: LearningLogger = null_learning_logger, seed: int = 100, ) -> None: """ @@ -153,8 +158,7 @@ def offline_learning( batch = data_buffer.sample(offline_agent.policy_learner.batch_size) assert isinstance(batch, TransitionBatch) loss = offline_agent.learn_batch(batch=batch) - if i % 500 == 0: - print("training epoch", i, "training loss", loss) + logger(loss, i, "training") def offline_evaluation( diff --git a/pearl/utils/functional_utils/train_and_eval/online_learning.py b/pearl/utils/functional_utils/train_and_eval/online_learning.py index 62e7dcb9..77441b3e 100644 --- a/pearl/utils/functional_utils/train_and_eval/online_learning.py +++ b/pearl/utils/functional_utils/train_and_eval/online_learning.py @@ -80,6 +80,7 @@ def online_learning( print_every_x_steps: Optional[int] = None, seed: Optional[int] = None, record_period: int = 1, + # TODO: use LearningLogger similarly to offline_learning ) -> Dict[str, Any]: """ Performs online learning for a number of episodes.