From 9c9e1af83afcb92527d8ab9b3664d6a65d518864 Mon Sep 17 00:00:00 2001 From: Omar Younis <42100908+younik@users.noreply.github.com> Date: Thu, 31 Oct 2024 20:17:43 +0100 Subject: [PATCH] Add pre-commit workflow & fix remaining pyright issues (#209) * fix remianing issue and add pre-commit to workflow * use python 3.10 --- .github/workflows/pre-commit.yml | 20 ++++++++++++++++++++ src/gfn/containers/replay_buffer.py | 6 ++++++ src/gfn/containers/trajectories.py | 17 ++++------------- src/gfn/env.py | 8 ++++---- src/gfn/states.py | 8 ++++---- testing/test_parametrizations_and_losses.py | 2 +- 6 files changed, 39 insertions(+), 22 deletions(-) create mode 100644 .github/workflows/pre-commit.yml diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 00000000..6b552a0d --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,20 @@ +# https://pre-commit.com +# This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. +--- +name: pre-commit +on: [push] + +permissions: + contents: read + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - run: pip install .[all] + - run: pre-commit --version + - run: pre-commit run --all-files \ No newline at end of file diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 3b187430..23bc1065 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -180,6 +180,12 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # Our buffer is full and we will prioritize diverse, high reward additions. else: + if ( + self.training_objects.log_rewards is None + or training_objects.log_rewards is None + ): + raise ValueError("log_rewards must be defined for prioritized replay.") + # Sort the incoming elements by their logrewards. ix = torch.argsort(training_objects.log_rewards, descending=True) training_objects = training_objects[ix] diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index f9e8d87f..5feb665a 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -15,11 +15,6 @@ from gfn.utils.common import has_log_probs -def is_tensor(t) -> bool: - """Checks whether t is a torch.Tensor instance.""" - return isinstance(t, torch.Tensor) - - # TODO: remove env from this class? class Trajectories(Container): """Container for complete trajectories (starting in $s_0$ and ending in $s_f$). @@ -113,7 +108,7 @@ def __init__( ) else: log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float) - self.log_probs = log_probs + self.log_probs: torch.Tensor = log_probs self.estimator_outputs = estimator_outputs if self.estimator_outputs is not None: @@ -187,7 +182,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories: log_rewards = ( self._log_rewards[index] if self._log_rewards is not None else None ) - if is_tensor(self.estimator_outputs): + if self.estimator_outputs is not None: # TODO: Is there a safer way to index self.estimator_outputs for # for n-dimensional estimator outputs? # @@ -292,13 +287,9 @@ def extend(self, other: Trajectories) -> None: # Either set, or append, estimator outputs if they exist in the submitted # trajectory. - if self.estimator_outputs is None and isinstance( - other.estimator_outputs, torch.Tensor - ): + if self.estimator_outputs is None and other.estimator_outputs is not None: self.estimator_outputs = other.estimator_outputs - elif isinstance(self.estimator_outputs, torch.Tensor) and isinstance( - other.estimator_outputs, torch.Tensor - ): + elif self.estimator_outputs is not None and other.estimator_outputs is not None: batch_shape = self.actions.batch_shape n_bs = len(batch_shape) diff --git a/src/gfn/env.py b/src/gfn/env.py index fcfcc694..7a60e8ec 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -54,8 +54,8 @@ def __init__( assert s0.shape == state_shape if sf is None: sf = torch.full(s0.shape, -float("inf")).to(self.device) - assert sf.shape == state_shape - self.sf = sf + self.sf: torch.Tensor = sf + assert self.sf.shape == state_shape self.state_shape = state_shape self.action_shape = action_shape self.dummy_action = dummy_action @@ -381,11 +381,11 @@ def __init__( # The default dummy action is -1. if dummy_action is None: - dummy_action = torch.tensor([-1], device=device) + dummy_action: torch.Tensor = torch.tensor([-1], device=device) # The default exit action index is the final element of the action space. if exit_action is None: - exit_action = torch.tensor([n_actions - 1], device=device) + exit_action: torch.Tensor = torch.tensor([n_actions - 1], device=device) assert s0.shape == state_shape assert dummy_action.shape == action_shape diff --git a/src/gfn/states.py b/src/gfn/states.py index 02242f10..c95ac91d 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -337,11 +337,11 @@ def __init__( dtype=torch.bool, device=self.__class__.device, ) - assert forward_masks.shape == (*self.batch_shape, self.n_actions) - assert backward_masks.shape == (*self.batch_shape, self.n_actions - 1) - self.forward_masks = forward_masks - self.backward_masks = backward_masks + self.forward_masks: torch.Tensor = forward_masks + self.backward_masks: torch.Tensor = backward_masks + assert self.forward_masks.shape == (*self.batch_shape, self.n_actions) + assert self.backward_masks.shape == (*self.batch_shape, self.n_actions - 1) def clone(self) -> States: """Returns a clone of the current instance.""" diff --git a/testing/test_parametrizations_and_losses.py b/testing/test_parametrizations_and_losses.py index 20431d23..edaece25 100644 --- a/testing/test_parametrizations_and_losses.py +++ b/testing/test_parametrizations_and_losses.py @@ -19,7 +19,7 @@ BoxPFMLP, ) from gfn.modules import DiscretePolicyEstimator, ScalarEstimator -from gfn.utils.modules import DiscreteUniform, MLP, Tabular +from gfn.utils.modules import MLP, DiscreteUniform, Tabular N = 10 # Number of trajectories from sample_trajectories (changes tests globally).