-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix linter and add pre-commit CI #193
Conversation
assert isinstance( | ||
self.training_objects, (Trajectories, Transitions) | ||
) # TODO: becasue we use last_states... is it correct? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check this: look like self.cutoff_distance>=0
only when we are working with Trajectories
and Transitions
, but there is no real check for it during init. Should this branch work also with tuple of states?
estimator_outputs = self.estimator_outputs | ||
other_estimator_outputs = other.estimator_outputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[these things are required for narrowing types by the checker, which cannot be done on self attributes]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment explaining this (and other similar decisions which might confuse a future developer if they don't know the underlying reason)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, make sense
|
||
class States(ABC): | ||
|
||
class States(Container, ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check if this is okay
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having states inherit from Container does not change it's function at all, correct, but adds some handy save/load features and makes it automatically compatible with the buffers, correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, exactly. For example, Container have load/save
which is used in ReplayBuffer.
Hi Omar thanks this is great. Perhaps can you split this into two PRs - one for typing and one for the replay buffer? I think the typing requires more investigation before we undo all of that work, but the replay buffer changes seem very important to resolve soon. Thanks so much! |
Sure, I did it here: #202 |
Thanks Omar for initiating this important work. If we were to indeed remove One way to do that is to create a repo-wide subclass of |
This is a very good idea, indeed. I will do the work in other smaller PRs (drop torch typing, fix other typing errors, ...) for the sake of review. |
Hey Omar just a heads up - doing the review now but looks like you might have some merge conflicts to resolve as well. I like the plan stated above by you two RE: removing |
Hey @josephdviviano, yes, please check #204. The plan is to break this PR into pieces, so I will close it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK awesome PR - thanks - I have a few requests (sorry one of them i s super annoying -- the nature of a PR like this).
Really appreciate your attention to detail.
|
||
from gfn.containers import Trajectories | ||
from gfn.env import Env | ||
from gfn.gflownet.base import TrajectoryBasedGFlowNet | ||
from gfn.modules import GFNModule, ScalarEstimator | ||
|
||
ContributionsTensor = TT["max_len * (1 + max_len) / 2", "n_trajectories"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, I think the notes will be particularly important.
|
||
class States(ABC): | ||
|
||
class States(Container, ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having states inherit from Container does not change it's function at all, correct, but adds some handy save/load features and makes it automatically compatible with the buffers, correct?
@younik please use my notes from this review in any future PR, I don't want to have to re-do it once more. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for making you review this big one.
If this PR looks good to you, we can move this forward and add the asserts (and docs) on shape immediately after. Otherwise, we can move the smaller one forward (I will integrate the comments here that are related to that one).
estimator_outputs = self.estimator_outputs | ||
other_estimator_outputs = other.estimator_outputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, make sense
src/gfn/gflownet/base.py
Outdated
@@ -124,8 +131,8 @@ def get_pfs_and_pbs( | |||
fill_value: float = 0.0, | |||
recalculate_all_logprobs: bool = False, | |||
) -> Tuple[ | |||
TT["max_length", "n_trajectories", torch.float], | |||
TT["max_length", "n_trajectories", torch.float], | |||
Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
class States(ABC): | ||
|
||
class States(Container, ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, exactly. For example, Container have load/save
which is used in ReplayBuffer.
I open it again and make it a draft, otherwise, it doesn't get updated. |
Hey @younik - I assume this is now safe to close? |
Yes, I close it. |
Apologies for the big PR; I can split it in two if it is too hassle to review.
This PR adds the pre-commit check on CI and fixes all the problems with linters.
To fix the problems with pyright:
torchtyping
, in favor oftorch.Tensor
only. In this way, we lost shape typing (which is a nice feature), but we can have static typing checks. I suggest documenting well the functions with the expected shape and adding assert on shape when necessary.ReplayBuffer.add
. Please check the fix as it is different from Small fixes #192ReplayBuffer
itself is a bit problematic because it can accept Trajectories, Transitions, and tuples of States. While Trajectories and Transitions inherit from Container, States doesn't. Also, the code has a bunch of if-else to handle the cases differently. Should States (or a new class StatesTuple) inherit from Container, and require for Container objects a complete interface so ReplayBuffer can be agnostic to the underline object? Or should we have different subclasses of ReplayBuffer? For simplicity, atm, I proposed to make States inherit from Container, but I suggest we think about it (e.g. an object tuple of states may be more appropriate to inherit from Container). I can address this in a future PR.