Skip to content

Commit

Permalink
Introduce LearningLogger
Browse files Browse the repository at this point in the history
Summary: Introduce protocol LearningLogger for logging with `offline_learning`.

Reviewed By: PavlosApo

Differential Revision: D65361620

fbshipit-source-id: 9fe9c9a4752cb549185c163d37c282a90d55921a
  • Loading branch information
rodrigodesalvobraz authored and facebook-github-bot committed Nov 4, 2024
1 parent c11becd commit ececba7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
26 changes: 26 additions & 0 deletions pearl/utils/functional_utils/train_and_eval/learning_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from typing import Any, Dict, Protocol


class LearningLogger(Protocol):
"""Protocol for a learning logger.
A learning logger is a callable that takes in a dictionary of results and a step number.
It can be used to log the results of a learning process to a database or a file.
Args:
results: A dictionary of results.
step: The current step of the learning process.
prefix: A prefix to add to the logged results.
"""

def __call__(self, results: Dict[str, Any], step: int, prefix: str = "") -> None:
pass


def null_learning_logger(results: Dict[str, str], step: int, prefix: str = "") -> None:
"""
A null learning logger that does nothing.
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.experimentation.set_seed import set_seed
from pearl.utils.functional_utils.requests_get import requests_get
from pearl.utils.functional_utils.train_and_eval.learning_logger import (
LearningLogger,
null_learning_logger,
)
from pearl.utils.functional_utils.train_and_eval.online_learning import run_episode
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

Expand Down Expand Up @@ -131,6 +135,7 @@ def offline_learning(
offline_agent: PearlAgent,
data_buffer: ReplayBuffer,
training_epochs: int = 1000,
logger: LearningLogger = null_learning_logger,
seed: int = 100,
) -> None:
"""
Expand All @@ -153,8 +158,7 @@ def offline_learning(
batch = data_buffer.sample(offline_agent.policy_learner.batch_size)
assert isinstance(batch, TransitionBatch)
loss = offline_agent.learn_batch(batch=batch)
if i % 500 == 0:
print("training epoch", i, "training loss", loss)
logger(loss, i, "training")


def offline_evaluation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def online_learning(
print_every_x_steps: Optional[int] = None,
seed: Optional[int] = None,
record_period: int = 1,
# TODO: use LearningLogger similarly to offline_learning
) -> Dict[str, Any]:
"""
Performs online learning for a number of episodes.
Expand Down

0 comments on commit ececba7

Please sign in to comment.