feat: switch sebulba to using shard_map like mava #127
base: main
Conversation
Thanks so much, I'll try to review this and test it tomorrow on a GPU.
```diff
@@ -684,7 +706,7 @@ def run_experiment(_config: DictConfig) -> float:
     )

     # Get initial parameters
-    initial_params = unreplicate(learner_state.params)
+    initial_params = jax.device_put(learner_state.params, actor_devices[0])
```
This is the only thing I am not sure of. Why do we put the initial params on the first actor device instead of the CPU or something?
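For reference, here is a minimal sketch (toy params, not the PR's actual `LearnerState`) of the two placements being discussed: pinning the initial params to the first actor device, as the diff does, versus keeping them on the host CPU.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for learner_state.params.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}

actor_devices = jax.local_devices()   # assumption: local devices act as the actor devices
cpu = jax.devices("cpu")[0]           # the host CPU backend is always available

# What the diff does: commit the initial params to the first actor device.
params_on_actor = jax.device_put(params, actor_devices[0])

# The alternative raised here: keep them on the CPU until the params source
# fans them out to each actor device.
params_on_cpu = jax.device_put(params, cpu)

print(params_on_actor["w"].sharding)  # single-device sharding on actor_devices[0]
print(params_on_cpu["w"].sharding)    # single-device sharding on the CPU device
```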
I just did a comparison, and it seems like sebulba on main is faster. Looking at all the timing statistics, it's the pipeline that is slowing things down; everything else on this branch is faster. I am trying this on a 2-GPU system. I'll try to figure out why it's slow.

So basically, because the actors are faster now, the pipeline fills up faster, which then causes the insertions to slow down a lot. I think the extra waiting and blocking increases the overheads. I'm not sure of the solution to it though.
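To make the blocking effect concrete, here is a toy sketch (plain Python queues, not the actual Stoix/Mava pipeline) of why faster actors inflate the measured insert time once a bounded pipeline fills up.

```python
import queue
import threading
import time

# Small bounded buffer, like the rollout pipeline between actors and the learner.
pipeline: queue.Queue = queue.Queue(maxsize=4)

def actor(n: int) -> None:
    for i in range(n):
        time.sleep(0.01)                    # fast rollout
        start = time.perf_counter()
        pipeline.put(f"trajectory_{i}")     # blocks while the queue is full
        print(f"insert {i}: waited {time.perf_counter() - start:.3f}s")

def learner(n: int) -> None:
    for _ in range(n):
        pipeline.get()
        time.sleep(0.05)                    # slower learn step keeps the queue full

threads = [
    threading.Thread(target=actor, args=(20,)),
    threading.Thread(target=learner, args=(20,)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Once the queue is full, any extra actor speed just shows up as time spent waiting inside `put`, which matches the timing statistics described above.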
```diff
@@ -809,9 +809,9 @@ def run_experiment(_config: DictConfig) -> float:
     logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

     # Evaluate the current model and log the metrics
-    learner_state_cpu = jax.device_get(learner_state)
+    eval_learner_state = jax.device_put(learner_state, evaluator_device)
```
I think CPU for this is best because it should never block the actors or learners, but I'm not sure, because it would be much faster on an accelerator 🤔
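As a point of comparison, a minimal sketch (illustrative names, not the PR's code) of the two evaluation placements being weighed up:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the real learner state.
learner_state = {"params": {"w": jnp.ones((8, 8))}}

cpu = jax.devices("cpu")[0]
evaluator_device = jax.local_devices()[-1]  # assumption: one accelerator reserved for eval

# CPU option: like the old jax.device_get, evaluation can never contend with
# the actor or learner devices, but the eval rollouts run on the host.
eval_state_cpu = jax.device_put(learner_state, cpu)

# Accelerator option (what the diff does): much faster evaluation, but it
# shares the accelerator budget with actors and learners.
eval_state_acc = jax.device_put(learner_state, evaluator_device)
```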
I think faster actors should almost always be better; maybe it just needs some hyperparameter tuning?
What?
Some mava upgrades, as promised 😄
Quite a few nice changes by upgrading to `shard_map` over `pmap` here, which avoids some unnecessary device transfers. I found these in mava using the transfer guard, quite a nice tool (a usage sketch is included after the NOTE below).

Why?
To make sebulba go brrr
How?
- Use `shard_map` instead of `pmap`.
- Explicit `put` of the params (see the diffs above).
- `block_until_ready` on the params transfer from the params source to the learner. I think the `unreplicate` that was in the learner before was doing this; without this block we get weird and undefined seg faults (see the sketch below).
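A minimal sketch of the `block_until_ready` point (illustrative names, not the PR's actual params source): block on the cross-device params transfer before another thread is allowed to consume it.

```python
import queue
import jax
import jax.numpy as jnp

params_source: queue.Queue = queue.Queue(maxsize=1)  # toy stand-in for a params source
target_device = jax.local_devices()[0]

def publish_params(new_params) -> None:
    # device_put dispatches asynchronously and returns immediately.
    on_device = jax.device_put(new_params, target_device)
    # Block until the transfer has actually finished before handing the arrays
    # to another thread; previously the unreplicate call synchronised implicitly,
    # and skipping this is the kind of race that can end in a seg fault.
    jax.block_until_ready(on_device)
    if params_source.full():
        params_source.get_nowait()
    params_source.put(on_device)

publish_params({"w": jnp.ones((4, 4))})
```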
NOTE:
This is very much not benchmarked. I pulled in Mava's changes in about an hour and tested it locally; it solves CartPole, but I haven't checked on a TPU or with a harder environment.
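As mentioned in the description, here is a self-contained sketch (not the PR's actual learner) of the two tools it leans on: `shard_map` for the data-parallel update and `jax.transfer_guard` to surface any unintended device transfers while it runs.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.local_devices()), axis_names=("d",))

def learner_step(params, batch):
    # Toy SGD step on the local batch shard, with gradients averaged across devices.
    grads = jax.grad(lambda p, x: jnp.mean((x @ p) ** 2))(params, batch)
    grads = jax.lax.pmean(grads, axis_name="d")
    return params - 0.1 * grads

sharded_step = jax.jit(
    shard_map(
        learner_step,
        mesh=mesh,
        in_specs=(P(), P("d")),  # params replicated, batch sharded over the "d" axis
        out_specs=P(),           # updated params come back replicated
    )
)

n_dev = len(jax.local_devices())
params = jax.device_put(jnp.ones((4, 2)), NamedSharding(mesh, P()))
batch = jax.device_put(jnp.ones((8 * n_dev, 4)), NamedSharding(mesh, P("d")))

# With "disallow", any implicit host<->device copy inside the step raises an
# error, which is how hidden transfers like the ones mentioned above show up.
with jax.transfer_guard("disallow"):
    params = sharded_step(params, batch)
```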