[BUG] Replay Buffer + Non-Flat Observations Fail #116

Closed
JesseSilverberg opened this issue Sep 11, 2024 · 1 comment · Fixed by #115
Labels
bug Something isn't working

Comments

@JesseSilverberg
Contributor

Describe the bug

Running an algorithm that uses a replay buffer on an environment with non-flat observations raises a shape error:

Traceback (most recent call last):
  File "/home/REDACTED/projects/Stoix/stoix/systems/q_learning/ff_dqn.py", line 572, in hydra_entry_point
    eval_performance = run_experiment(cfg)
  File "/home/REDACTED/projects/Stoix/stoix/systems/q_learning/ff_dqn.py", line 443, in run_experiment
    learn, eval_q_network, learner_state = learner_setup(env, (key, q_net_key), config)
  File "/home/REDACTED/projects/Stoix/stoix/systems/q_learning/ff_dqn.py", line 414, in learner_setup
    env_states, timesteps, keys, buffer_states = warmup(
  File "/home/REDACTED/projects/Stoix/stoix/systems/q_learning/ff_dqn.py", line 84, in warmup
    buffer_states = buffer_add_fn(buffer_states, traj_batch)
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/flashbax/buffers/item_buffer.py", line 96, in add_fn
    return buffer.add(state, flattened_batch)
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/flashbax/buffers/trajectory_buffer.py", line 149, in add
    experience = jax.tree_util.tree_map(
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/flashbax/buffers/trajectory_buffer.py", line 150, in <lambda>
    lambda experience_field, batch_field: experience_field.at[:, indices].set(
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 500, in set
    return scatter._scatter_update(self.array, self.index, values, lax.scatter,
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 76, in _scatter_update
    return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 111, in _scatter_impl
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2252, in broadcast_to
    return util._broadcast_to(array, shape)
  File "/home/REDACTED/projects/Stoix/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 421, in _broadcast_to
    raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(1, 16384, 10, 10, 7) shape=(1, 16384, 700)
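
The final line of the traceback shows the mismatch: the batch being added carries observations of shape (10, 10, 7), while the indexed slice of the buffer storage expects a flat vector of length 700 (10 × 10 × 7). The sketch below reproduces that failure in isolation, using the same `.at[:, indices].set(...)` pattern that appears in the traceback. It makes several assumptions: the sizes are shrunk (8 slots and 4 items rather than 16384), `try_add` and `flatten_obs` are hypothetical helpers written for this illustration, and flattening the trailing dimensions is only one possible workaround, not necessarily what #115 does.

```python
import jax.numpy as jnp


def try_add(storage, batch, indices):
    """Write `batch` into `storage` at `indices` along axis 1.

    Returns (updated_storage, None) on success or (None, error) on failure.
    The traceback above shows a ValueError; TypeError is caught as well in
    case another JAX version reports the shape mismatch differently.
    """
    try:
        return storage.at[:, indices].set(batch), None
    except (ValueError, TypeError) as err:
        return None, err


def flatten_obs(batch):
    """Hypothetical helper: keep the two leading axes, flatten the rest."""
    return batch.reshape(*batch.shape[:2], -1)


# Storage allocated for flat observations of length 700 (assumed layout).
storage = jnp.zeros((1, 8, 700))
# Batch with non-flat observations of shape (10, 10, 7).
batch = jnp.ones((1, 4, 10, 10, 7))
indices = jnp.arange(4)

# Non-flat batch: fails because it has more dimensions than the target slice.
_, err = try_add(storage, batch, indices)
print(type(err).__name__, err)

# Flattened batch: shapes agree, so the scatter update goes through.
flat = flatten_obs(batch)
updated, err2 = try_add(storage, flat, indices)
print(flat.shape, updated.shape, err2)
```

The alternative to flattening the batch is to allocate the storage with the structured observation shape, so that both sides agree on (10, 10, 7). Either way, the shape used to initialise the buffer and the shape of the added experience have to match.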

To Reproduce

python stoix/systems/q_learning/ff_dqn.py env=gymnax/freeway

Context (Environment)

This is on the latest version of Stoix (installed a few days ago).

Possible Solution

#115

@JesseSilverberg added the bug label Sep 11, 2024
@EdanToledo linked a pull request that will close this issue Sep 12, 2024
@EdanToledo
Owner

Hey, thanks so much for pointing this out. Let's move the discussion over to the PR, but I imagine it will be a quick one :)
