fsdp=16 model=16 gbs=16 should work on 256 chips #773
Hi @samos123, you can use the input dispatcher: see axlearn/axlearn/common/input_tf_data.py, lines 1165 to 1167 at commit ac63eef, and axlearn/axlearn/common/input_dispatch.py, lines 17 to 33 at commit ac63eef.
Some hosts will produce padding feeds, which are dropped during input dispatch. I have some ideas to make this a bit simpler soon, but this should unblock you for now.
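To make the padding-feed idea concrete, here is a toy model in plain Python. It is not axlearn's InputDispatcher; the function names and the -1 padding marker are hypothetical. Every feed (host) emits the same number of physical rows, only the feeds listed in logical_feed_indices carry real examples, and dispatch keeps just those real rows.

```python
PAD = -1  # hypothetical marker for padding rows in this toy model


def build_physical_batch(num_feeds, physical_per_feed, logical_feed_indices, logical_per_feed):
    """Return a flat global physical batch of example ids.

    Each feed contributes physical_per_feed rows. Feeds in logical_feed_indices
    place logical_per_feed real example ids first and pad the rest; every other
    feed contributes only padding.
    """
    if logical_per_feed > physical_per_feed:
        raise ValueError("a feed cannot hold more real rows than physical rows")
    batch = [PAD] * (num_feeds * physical_per_feed)
    next_id = 0
    for feed in logical_feed_indices:
        start = feed * physical_per_feed
        for j in range(logical_per_feed):
            batch[start + j] = next_id  # a "real" example id
            next_id += 1
    return batch


def dispatch_to_logical(batch, physical_per_feed, logical_feed_indices, logical_per_feed):
    """Gather the real rows from the logical feeds and drop all padding."""
    return [
        batch[feed * physical_per_feed + j]
        for feed in logical_feed_indices
        for j in range(logical_per_feed)
    ]


# 8 feeds x 2 rows = 16 physical rows; feeds 0 and 4 each contribute 1 real row.
physical = build_physical_batch(8, 2, [0, 4], 1)
logical = dispatch_to_logical(physical, 2, [0, 4], 1)
```

Here 14 of the 16 physical rows are padding, and the logical batch that reaches training has only 2 rows.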
Maybe I'm misunderstanding the code, but input_dispatcher is None for the Fuji models, so wouldn't it already default to InputDispatcher? @markblee
Are you saying I should create a custom InputDispatcher and pass that in instead? That may make sense; I'm looking into it further.
Would these be the right settings for fsdp=16, model=16, and gbs=16 on v6e-256?
Currently, this is set in the Fuji models:
This is what's inside the input_tf_data.batch function:
So I suspect I need to modify the batch function directly.
@samos123 please read through the input logic. There is a logical batch and a physical batch; once you understand those two concepts, you should be good to go.
global_logical_batch_size = 16
global_physical_batch_size = 256  # You can also try 64 here to see if it works.
logical_feed_indices = list(range(0, 64, 4))  # 1 in 4 hosts reads a single "real" batch.
You don't need to specify other fields because the batcher can infer them. Also, as Kelvin said, be sure to understand the meaning of these fields.
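The arithmetic behind these settings can be checked with a short script. It assumes one input feed per host and 64 hosts on v6e-256 (4 chips per host), which is what range(0, 64, 4) implies; the helper name is hypothetical and only mirrors the relationships described in this thread, not axlearn's own validation code.

```python
def describe_batch_config(num_chips, chips_per_host, global_logical_batch_size,
                          global_physical_batch_size, logical_feed_indices):
    """Derive per-feed batch sizes, assuming one input feed per host."""
    num_feeds = num_chips // chips_per_host
    if global_physical_batch_size % num_feeds:
        raise ValueError("physical batch must split evenly across feeds")
    if any(not 0 <= i < num_feeds for i in logical_feed_indices):
        raise ValueError("logical feed index out of range")
    if global_logical_batch_size % len(logical_feed_indices):
        raise ValueError("logical batch must split evenly across logical feeds")
    physical_per_feed = global_physical_batch_size // num_feeds
    logical_per_feed = global_logical_batch_size // len(logical_feed_indices)
    if logical_per_feed > physical_per_feed:
        raise ValueError("each logical feed needs room for its real examples")
    return {
        "num_feeds": num_feeds,
        "num_logical_feeds": len(logical_feed_indices),
        "physical_per_feed": physical_per_feed,
        "logical_per_feed": logical_per_feed,
    }


# 64 feeds, 16 of them logical; each logical feed holds 1 real row among its 4 physical rows.
cfg = describe_batch_config(256, 4, 16, 256, list(range(0, 64, 4)))
# With a physical batch of 64, every feed holds exactly 1 physical row.
cfg_64 = describe_batch_config(256, 4, 16, 64, list(range(0, 64, 4)))
```

Under these assumptions, both physical batch sizes are consistent with a logical batch of 16; the larger one simply carries more padding per feed.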
Just sharing for now since it's related: I also hit this error when trying fsdp=16, model=16, and gbs=128 on 256 chips:
@samos123 please understand the physical batch logic: you shouldn't use a global batch size of 128 over 256 chips; the physical batch is always 256. The logical batch is what you care about more. The physical batch could be configured automatically from the logical batch, but we don't have that in place yet. FWIW, you also cannot use model=16, because it exceeds the number of KV heads, which is 8 in the GQA model. But that's a separate issue, and you haven't hit it yet.
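A quick sanity check for the mesh point above, assuming the model axis shards the KV heads (so each shard must hold a whole number of them) and that fsdp * model must cover all chips. The function is hypothetical and not part of axlearn; the KV-head count of 8 comes from the comment above.

```python
def mesh_problems(num_chips, fsdp, model, num_kv_heads):
    """Return human-readable problems with an (fsdp, model) mesh; empty if none."""
    problems = []
    if fsdp * model != num_chips:
        problems.append(f"fsdp*model = {fsdp * model}, but there are {num_chips} chips")
    if num_kv_heads % model:
        # Assumption: each model-parallel shard needs a whole number of KV heads.
        problems.append(f"{num_kv_heads} KV heads cannot be split evenly across model={model}")
    return problems


print(mesh_problems(256, 16, 16, 8))  # the configuration from this thread
print(mesh_problems(256, 32, 8, 8))   # model axis equal to the KV-head count
```

The first call flags only the KV-head split; passing both checks does not by itself guarantee the configuration fits in memory.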
The use case is being able to use a global batch size smaller than the total number of JAX processes.
This is supported in MaxText using this trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291
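As we read it, the idea is to load a batch large enough that every JAX process gets at least one example, then slice it down to the smaller batch actually used for training. The sketch below works under that assumption only; the names are illustrative and this is not MaxText's code.

```python
def slice_to_train_batch(loaded_batch, train_batch_size):
    """Keep only the first train_batch_size rows of every field."""
    for name, rows in loaded_batch.items():
        if len(rows) < train_batch_size:
            raise ValueError(f"field {name!r} has fewer than {train_batch_size} rows")
    return {name: rows[:train_batch_size] for name, rows in loaded_batch.items()}


# Hypothetical: 64 processes each load one example, but we train on 16.
loaded = {
    "input_ids": [[i, i + 1] for i in range(64)],
    "target_ids": [[i + 1, i + 2] for i in range(64)],
}
train_batch = slice_to_train_batch(loaded, 16)
```

The trade-off is that the extra loaded rows are read and then discarded, which is similar in spirit to the padding feeds dropped by an input dispatcher.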
I'm trying to get the 405B model running on v6e-256 (fsdp=16, model=16) but am hitting this error: