adds Granite Liger Kernel option for Granite3.y models #430
base: main
Conversation
Current CI failure seems unrelated to the liger_kernel addition.
Need to test, but this could be very useful if it works correctly.
Currently, correctness tests for the kernels themselves are done in-tree in the liger-kernel repo; they assure correct convergence and logit equivalence after training. Link to test here.

I've also validated training-dynamics equivalence given identical batches. Link to Jira issue here. There's roughly a 1% raw improvement with batches being identical, which is expected; the real benefits will come from larger possible batch sizes! This kernel set is also missing the most important kernel.

I've given the PSAP team a ref to this PR so they can quantify the improved memory headroom and find new `max_batch_lens`.
@JamesKunstle Thanks for adding this; a few comments:
```python
try:
    # Third Party
    from liger_kernel.transformers import apply_liger_kernel_to_granite
except ImportError:
    # Degrade to a no-op when liger-kernel isn't installed.
    apply_liger_kernel_to_granite = lambda *args, **kwargs: None  # pylint: disable=C3001
```
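(With this guard, the trainer can call `apply_liger_kernel_to_granite` unconditionally: if liger-kernel isn't installed, the call silently becomes a no-op instead of raising an `ImportError` at startup.)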
@JamesKunstle We'll want to support other models beyond Granite, and it seems like Liger's API has a number of different functions they expose for common architectures (mistral, llama, etc.). It seems like they provide a way to map directly from the `model_type` field on the model into the Liger kernel to apply. Could we please use that here to support more models?
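For reference, a minimal sketch of that kind of `model_type` dispatch. The per-architecture imports follow Liger's `apply_liger_kernel_to_<arch>` naming, but the mapping table and the `maybe_apply_liger_kernel` helper are illustrative assumptions, not code from this PR (Liger also ships `AutoLigerKernelForCausalLM`, which handles this dispatch internally):

```python
# Illustrative sketch: dispatch on the HF config's model_type instead of
# hard-coding Granite. The mapping below is a hand-rolled assumption.
from transformers import AutoConfig

try:
    # Third Party
    from liger_kernel.transformers import (
        apply_liger_kernel_to_granite,
        apply_liger_kernel_to_llama,
        apply_liger_kernel_to_mistral,
    )

    LIGER_APPLY_FN_BY_MODEL_TYPE = {
        "granite": apply_liger_kernel_to_granite,
        "llama": apply_liger_kernel_to_llama,
        "mistral": apply_liger_kernel_to_mistral,
    }
except ImportError:
    LIGER_APPLY_FN_BY_MODEL_TYPE = {}


def maybe_apply_liger_kernel(model_name_or_path: str, **kwargs) -> bool:
    """Patch the matching architecture with Liger kernels, if supported."""
    model_type = AutoConfig.from_pretrained(model_name_or_path).model_type
    apply_fn = LIGER_APPLY_FN_BY_MODEL_TYPE.get(model_type)
    if apply_fn is None:
        return False
    apply_fn(**kwargs)
    return True
```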
++ certainly; we're only limiting this to Granite to get the PSAP numbers, then we'll expand the integration to other supported architectures.
```diff
@@ -974,6 +995,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
             "The last checkpoint will be saved as 'last_epoch'."
         ),
     )
+
+    # this will work for granite-3.y models but not granite-7b because that's a Llama 2 model arch.
+    parser.add_argument("--enable-granite-liger-kernel", action="store_true")
```
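For context, a hypothetical sketch of the consuming side; where exactly this hook lives relative to `run_training` is an assumption, and `apply_liger_kernel_to_granite` is the guarded import from earlier:

```python
# Hypothetical wiring (not part of this diff): gate the Liger patch on the
# new CLI flag before the model is loaded. The guarded import above makes
# this a safe no-op when liger-kernel isn't installed.
args = parser.parse_args()
if args.enable_granite_liger_kernel:
    apply_liger_kernel_to_granite()
```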
If we're providing this as an arg through the CLI, we should also expose this in the `TrainingArgs` config so the SDK can be consistent.
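For illustration, a minimal sketch of mirroring the flag in the config, assuming `TrainingArgs` is a pydantic model; the field name and default here are assumptions:

```python
from pydantic import BaseModel


class TrainingArgs(BaseModel):
    # ... existing training fields elided ...

    # Assumed mirror of the --enable-granite-liger-kernel CLI flag. Off by
    # default; only meaningful for granite-3.y (granite-7b is a Llama 2 arch).
    enable_granite_liger_kernel: bool = False
```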
++ agreed, will do once we've got the Granite numbers from PSAP
Support for Granite was added in Liger Kernel (LK) v0.5.4, so we can add it as an additional performance option.