fix: make FlattenObservationWrapper also flatten next_obs #115
What?
This PR modifies `FlattenObservationWrapper` to also perform the same flattening transform on `next_obs` (right now, the wrapper only modifies the current observation).
Why?
Currently, running an algorithm that uses a replay buffer on an environment with non-flat observations fails (e.g., DQN + gymnax/freeway). The failure is a shape mismatch, caused by `FlattenObservationWrapper` not flattening `extras['next_obs']`.
How?
Since this change involved duplicating some logic, I also took the liberty of moving a repetitive set of two lines to its own function.
Extra
Flagging that I'm not familiar enough with Stoix to know whether `if 'next_obs' in timestep.extras:` is the appropriate check/assumption. It looks like `.extras` will always be there, so I don't think this will cause issues, but wanted to bring it to your attention.
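
For reviewers, here is a rough sketch of the shape of the change. It is not the actual Stoix code. `TimeStep`, `flatten`, and `flatten_timestep` are hypothetical stand-ins, and nested lists take the place of JAX arrays so the snippet runs with only the standard library. It shows the shared flattening helper applied to both the current observation and `extras['next_obs']`, guarded by the membership check discussed above.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict


@dataclass(frozen=True)
class TimeStep:
    """Hypothetical stand-in for the environment's timestep type."""
    observation: Any
    extras: Dict[str, Any] = field(default_factory=dict)


def flatten(obs):
    """Flatten an arbitrarily nested list/tuple of numbers into a flat list."""
    if isinstance(obs, (list, tuple)):
        return [x for item in obs for x in flatten(item)]
    return [obs]


def flatten_timestep(timestep: TimeStep) -> TimeStep:
    """Flatten the current observation and, if present, extras['next_obs']."""
    extras = dict(timestep.extras)  # copy so the input timestep is not mutated
    if "next_obs" in extras:  # assumes extras is always a dict, possibly empty
        extras["next_obs"] = flatten(extras["next_obs"])
    return replace(timestep, observation=flatten(timestep.observation), extras=extras)


grid = [[1, 2], [3, 4]]
with_next = flatten_timestep(
    TimeStep(observation=grid, extras={"next_obs": [[5, 6], [7, 8]]})
)
without_next = flatten_timestep(TimeStep(observation=grid))
```

In this sketch the `in` check is safe whenever `extras` is a dict, including an empty one. It would only raise a `TypeError` if `extras` could be `None`, which is the assumption worth confirming.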