Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make FlattenObservationWrapper also flatten next_obs #115

Merged
merged 2 commits into from
Sep 12, 2024

Conversation

JesseSilverberg
Copy link
Contributor

What?

This PR modifies FlattenObservationWrapper to also perform the same flattening transform on next_obs (right now, the wrapper only modifies the current observation).

Why?

Currently, running an algorithm that uses a replay buffer on an environment with non-flat observations will fail (e.g., DQN + gymnax/freeway). The error is one of shapes, caused by FlattenObservationWrapper not flattening extras['next_obs'].

How?

Since this change involved duplicating some logic, I also took the liberty of moving a repetitive set of two lines to its own function.

Extra

Flagging that I'm not familiar enough with Stoix to know whether if 'next_obs' in timestep.extras: is the appropriate check/assumption. It looks like .extras will always be there, so I don't think this will cause issues, but wanted to bring it to your attention.

@EdanToledo
Copy link
Owner

Hello, thanks so much for this. The code looks great but could you run the pre-commit to format the code. You simply need to install the dev requirements using pip install -e .[dev] in the directory where you cloned the repo. or you can simply install the requirements-dev.txt. Then you need to run pre-commit run --all. That should format the code and point out if theres any things you are missing.

@JesseSilverberg
Copy link
Contributor Author

Oops, sorry about that. Checks should be passing now.

@EdanToledo EdanToledo merged commit 6125d36 into EdanToledo:main Sep 12, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] Replay Buffer + Non-Flat Observations Fail
2 participants