[JAX] Fix partitioning issues in LayerNorm and LayerNormMLP layers #1743
base: main
Conversation
Signed-off-by: Jeremy Berchtold <[email protected]>
…hape checks before calling TE API
Signed-off-by: Jeremy Berchtold <[email protected]>
Force-pushed from 5350e48 to 1260f07
/te-ci L0
NVTE_CHECK(act_len == 1 || act_len == 2,
           "The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
           "gated activation, got ",
           act_len);
checkDActShapes(input_buf, act_input_buf, output_buf);
I think we should be able to check all of these requirements in the Primitive, in the abstract() and/or lowering().
Good point. Checking in the abstract will enforce both the full shape and the sharded shape, since we call the abstract again on the inner sharded primitive.
Moved this check to the abstract function in Python and confirmed it still catches the issue when the partitioning fix isn't applied.
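For reference, a minimal sketch of what such a check in the primitive's abstract evaluation could look like. The function name, argument layout, and messages here are illustrative assumptions, not the actual TE code; because the abstract is re-invoked on the inner sharded primitive, the same check covers both the global and the per-shard shapes.

```python
# Hypothetical sketch only; names and shapes are illustrative, not TE's actual primitive.
from jax.core import ShapedArray

def dact_abstract(dz_aval: ShapedArray, x_aval: ShapedArray) -> ShapedArray:
    # x carries the activation inputs, with axis=-2 holding 1 (non-gated) or 2 (gated) tensors.
    act_len = x_aval.shape[-2]
    assert act_len in (1, 2), (
        f"The value of the activation dimension (axis=-2) must be 1 for non-gated "
        f"or 2 for gated activation, got {act_len}")
    # dz must match x on every axis except the activation axis; a mismatch would mean
    # the backward activation kernel reads out of bounds.
    expected_dz_shape = x_aval.shape[:-2] + x_aval.shape[-1:]
    assert dz_aval.shape == expected_dz_shape, (
        f"Expected dz shape {expected_dz_shape}, got {dz_aval.shape}")
    # The output gradient has the same shape/dtype as the activation input.
    return ShapedArray(x_aval.shape, x_aval.dtype)
```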
Signed-off-by: Jeremy Berchtold <[email protected]>
/te-ci L0
LGTM
Description
Fixes two issues encountered in the MaxText integration with TE: 1) NaNs on multi-GPU usage of the LayerNormMLP layer, and 2) partitioning shape mismatch errors when using TE's LayerNorm layer with TP > 1.
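A minimal sketch of the idea behind fix 1, not the TE implementation: pin the incoming gradient dz to the same sharding as the activation input x before the elementwise backward-activation math, so the two operands can never end up partitioned differently. The mesh axis name, shapes, and the stand-in dact computation are illustrative assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all local devices; the "tensor" axis stands in for the TP axis.
mesh = Mesh(np.array(jax.devices()), ("tensor",))
replicated = NamedSharding(mesh, P(None, None))  # not split along "tensor"

def dact(dz, x):
    # Constrain both operands to the same sharding before the elementwise math,
    # so dz cannot be partitioned differently from x.
    x = jax.lax.with_sharding_constraint(x, replicated)
    dz = jax.lax.with_sharding_constraint(dz, replicated)
    return dz * jax.nn.sigmoid(x)  # stand-in for the real dact kernel

x = jnp.ones((8, 128), jnp.float32)
dz = jnp.ones((8, 128), jnp.float32)
out = jax.jit(dact)(dz, x)
```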
Type of change
Changes
- Check that dz and x are the same shape.
- Make dz always use the same partitioning as x. In combination with the checks above, this prevents NaNs from running dact on dz and x partitioned differently, leading to out-of-bounds reads.
- Make x not partition along the TP axis. This prevents the partitioning shape errors we were encountering where x was partitioned along TP but the norm output wasn't.
Checklist: