
feat: switch sebulba to using shard_map like mava #127

Open · wants to merge 4 commits into base: main

Conversation


@sash-a commented Nov 3, 2024

What?

Some mava upgrades, as promised 😄
Quite a few nice changes come from upgrading to shard_map over pmap here, which avoids some unnecessary device transfers. I found these in mava using JAX's transfer guard, quite a nice tool.
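
For context, a minimal sketch of the transfer-guard pattern mentioned above (not code from this PR; the toy array is just for illustration):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)  # lives on the default device

# Inside this guard, any *implicit* host<->device transfer raises an error,
# which makes accidental copies easy to spot.
with jax.transfer_guard("disallow"):
    y = x + 1      # fine: stays on device
    # float(y[0])  # would raise: implicit device-to-host transfer
```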

Why?

To make sebulba go brrr

How?

  • Switch to shard_map
  • Remove the default actor device - this was causing some unnecessary transfers in the pipeline
  • Stop using flax's replicate/unreplicate and instead explicitly jax.device_put (see the sketch after this list)
  • Move a block_until_ready from the params source to the learner. I think the unreplicate that was in the learner before was doing this implicitly; without this block we get weird, nondeterministic segfaults
  • One issue I see is that we're now passing the same key to all the learners. We actually do this in mava as well, and I realize it is a minor issue, but I'm not entirely sure how to fix it. I quickly tried switching the key's sharding to the data sharding, which I think should fix it, but it didn't... it's Sunday, so hopefully I'll have time to look at it during the week, or if you could find a solution that would be awesome
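
To make the shard_map and explicit-placement points concrete, here is a minimal sketch of the pattern (the mesh setup and the learn_step/params/batch names are placeholders, not Stoix code):

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

learner_devices = jax.devices()
mesh = Mesh(learner_devices, axis_names=("learner",))

def learn_step(params, batch):
    # Each learner device sees its own shard of `batch`; a real step would
    # compute gradients here and average them across learners with pmean.
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)  # placeholder gradients
    grads = jax.lax.pmean(grads, axis_name="learner")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

sharded_learn = shard_map(
    learn_step,
    mesh=mesh,
    in_specs=(P(), P("learner")),  # params replicated, data sharded over learners
    out_specs=P(),                 # updated params come back replicated
)

# Instead of flax's replicate/unreplicate, place arrays explicitly:
params = {"w": jnp.ones((4, 4))}
params = jax.device_put(params, NamedSharding(mesh, P()))  # replicate across learners
batch = jax.device_put(
    jnp.ones((8 * len(learner_devices), 4)),
    NamedSharding(mesh, P("learner")),  # shard the batch over learner devices
)
params = jax.jit(sharded_learn)(params, batch)
```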

NOTE:

This is very much not benchmarked. I pulled in Mava's changes in about an hour and tested it locally, and it solves cartpole, but I haven't checked on a TPU or with a harder env.

@EdanToledo (Owner)

Thanks so much, I'll try to review this and test it tomorrow on a GPU.

@@ -684,7 +706,7 @@ def run_experiment(_config: DictConfig) -> float:
     )

     # Get initial parameters
-    initial_params = unreplicate(learner_state.params)
+    initial_params = jax.device_put(learner_state.params, actor_devices[0])
@EdanToledo (Owner)

This is the only thing I am not sure of. Why do we put the initial params on the first actor device instead of the CPU or something?
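
For reference, the CPU option being asked about would look something like this (a hypothetical sketch; `params` stands in for learner_state.params):

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((4, 4))}  # stand-in for learner_state.params
cpu_device = jax.devices("cpu")[0]
params_on_cpu = jax.device_put(params, cpu_device)  # host copy, off the actor devices
```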

@EdanToledo (Owner)

I just did a comparison, and it seems like sebulba on main is faster. Looking at all the timing statistics, it's the pipeline that is slowing things down. Everything else on this branch is faster. I am trying this on a 2-GPU system. I'll try to figure out why it's slow.

@EdanToledo (Owner)

> I just did a comparison, and it seems like sebulba on main is faster. Looking at all the timing statistics, it's the pipeline that is slowing things down. Everything else on this branch is faster. I am trying this on a 2-GPU system. I'll try to figure out why it's slow.

So basically, because the actors are faster now, the pipeline fills up faster, which I think slows the insertions down a lot: there is more waiting and blocking, which I imagine increases the overheads. I'm not sure of the solution to it though.
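
As a toy illustration of the blocking effect described here (not Stoix's pipeline, just a bounded queue with a fast producer and a slow consumer):

```python
import queue
import threading
import time

pipeline = queue.Queue(maxsize=2)  # small buffer, like a nearly full rollout pipeline

def learner():
    for _ in range(6):
        time.sleep(0.5)  # slow consumer
        pipeline.get()

threading.Thread(target=learner, daemon=True).start()

# Fast producer: once the queue is full, put() blocks and the actor sits idle.
for step in range(6):
    t0 = time.time()
    pipeline.put(step)
    print(f"actor put {step}, waited {time.time() - t0:.2f}s")
```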

@@ -809,9 +809,9 @@ def run_experiment(_config: DictConfig) -> float:
     logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)

     # Evaluate the current model and log the metrics
-    learner_state_cpu = jax.device_get(learner_state)
+    eval_learner_state = jax.device_put(learner_state, evaluator_device)
@sash-a (Author)

I think CPU is best for this because it should never block actors or learners, but I'm not sure because it would be much faster on an accelerator 🤔

@sash-a (Author) commented Nov 4, 2024

> So basically, because the actors are faster now, the pipeline fills up faster, which I think slows the insertions down a lot: there is more waiting and blocking, which I imagine increases the overheads. I'm not sure of the solution to it though.

I think faster actors should almost always be better; maybe it just needs some hyperparameter tuning?
