
adds Granite Liger Kernel option for Granite 3.y models #430

Open · JamesKunstle wants to merge 1 commit into main from jkunstle/granite-liger-kernel
Conversation

@JamesKunstle (Contributor)

Support for Granite was added in Liger Kernel (LK) v0.5.4, so we can add it as an additional performance option.
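
For context, a minimal sketch of what this option enables, assuming the apply_liger_kernel_to_granite entry point added in Liger Kernel v0.5.4; the no-arg call and the model ID below are illustrative:

# Hedged sketch: patch the Granite 3.x module classes with Liger's fused Triton
# kernels before instantiating the model. apply_liger_kernel_to_granite is the
# entry point referenced by this PR; the model ID is illustrative only.
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_granite

apply_liger_kernel_to_granite()  # monkey-patches the Granite modules in place
model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.1-8b-instruct")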

@mergify bot added the ci-failure and dependencies labels on Mar 21, 2025
@JamesKunstle force-pushed the jkunstle/granite-liger-kernel branch 2 times, most recently from 1cd4e63 to 2b1ad16 on March 21, 2025 at 00:17
@mergify bot added and removed the ci-failure label on Mar 21, 2025
@JamesKunstle force-pushed the jkunstle/granite-liger-kernel branch from 2b1ad16 to e6ee97a on March 21, 2025 at 00:34
@JamesKunstle force-pushed the jkunstle/granite-liger-kernel branch from e6ee97a to 89aeca5 on March 21, 2025 at 00:36
@mergify bot added and removed the ci-failure label on Mar 21, 2025
@JamesKunstle (Contributor, Author)

The current CI failure seems unrelated to the liger_kernel addition.

@RobotSail (Member)

Need to test, but this could be very useful if it works correctly.

@JamesKunstle (Contributor, Author) commented Mar 22, 2025

Currently, correctness tests for the kernels themselves are done in-tree in the Liger-Kernel repo; they assert correct convergence and logit equivalence after training. Link to the tests here.

I've also validated training-dynamics equivalence given identical batches. Link to Jira issue here.

There's roughly a 1% raw improvement when the batches are held identical. This is expected: the real benefits will come from the larger batch sizes that the memory savings allow. This kernel set is also missing the most important kernel, LigerFusedLinearCrossEntropy, which skips logit materialization and saves a lot of net memory headroom. We'll add this in the future.
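
As an aside, a hedged sketch of what enabling that fused kernel might look like once it lands for Granite; the cross_entropy and fused_linear_cross_entropy kwargs are assumed from Liger's other apply_liger_kernel_to_* functions and may not apply to Granite yet:

# Hedged sketch, not this PR's current behavior: Liger's apply_* functions for
# other architectures accept kwargs toggling individual kernels. If/when the
# Granite patch gains fused linear cross-entropy, enabling it might look like:
from liger_kernel.transformers import apply_liger_kernel_to_granite

apply_liger_kernel_to_granite(
    cross_entropy=False,              # assumed kwarg: disable the standalone CE kernel
    fused_linear_cross_entropy=True,  # assumed kwarg: fuse the lm_head matmul with CE,
                                      # skipping full logit materialization
)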

I've given the PSAP team a ref to this PR so they can quantify the improved memory headroom and find new max_batch_lens.

@RobotSail (Member) left a comment

@JamesKunstle Thanks for adding this, a few comments

try:
    # Third Party
    from liger_kernel.transformers import apply_liger_kernel_to_granite
except ImportError:
    apply_liger_kernel_to_granite = lambda *args, **kwargs: None  # pylint: disable=C3001
@RobotSail (Member)

@JamesKunstle We'll want to support other models beyond Granite, and it seems like Liger's API has a number of different functions they expose for common architectures (mistral, llama, etc.).

It seems like they provide a way to map directly from the model_type field on the model into the Liger kernel to apply. Could we please use that here to support more models?

https://github.com/linkedin/Liger-Kernel/blob/293bf7eec7043c8c34b3cd82975c97e4c2f4254f/src/liger_kernel/transformers/monkey_patch.py#L1058C1-L1058C29
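
To illustrate the suggestion, a hedged sketch of dispatching on the config's model_type; MODEL_TYPE_TO_APPLY_LIGER_FN is taken from the linked monkey_patch.py and may differ across Liger versions, and maybe_apply_liger_kernel is a hypothetical helper name:

# Hedged sketch: look up the Liger patch function by the HF config's model_type
# instead of hard-coding Granite. The mapping name is assumed from the linked
# monkey_patch.py and may change between releases.
from transformers import AutoConfig

try:
    # Third Party
    from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
except ImportError:
    MODEL_TYPE_TO_APPLY_LIGER_FN = {}

def maybe_apply_liger_kernel(model_name_or_path: str, **liger_kwargs) -> bool:
    """Apply the matching Liger kernel patch if this architecture is supported."""
    model_type = AutoConfig.from_pretrained(model_name_or_path).model_type
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        return False  # architecture not covered by Liger; fall back to stock modules
    apply_fn(**liger_kwargs)
    return True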

@JamesKunstle (Contributor, Author)

++ certainly. We're limiting this to Granite for now to get the PSAP numbers, then we'll expand the integration to other supported architectures.

@@ -974,6 +995,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"The last checkpoint will be saved as 'last_epoch'."
),
)

# this will work for granite-3.y models but not granite-7b because that's a Llama 2 model arch.
parser.add_argument("--enable-granite-liger-kernel", action="store_true")
@RobotSail (Member)

If we're providing this as an arg through the CLI, we should also expose this in the TrainingArgs config so the SDK can be consistent.
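
For illustration, a hedged sketch of what exposing the option through TrainingArgs could look like, assuming TrainingArgs is a pydantic model as in the training SDK; only the new field is shown and its name (mirroring the CLI flag) is illustrative:

# Hedged sketch: mirror the CLI flag in the SDK config so both entry points stay
# consistent. Assumes TrainingArgs is a pydantic BaseModel; the field name is
# illustrative and mirrors --enable-granite-liger-kernel.
from pydantic import BaseModel, Field

class TrainingArgs(BaseModel):
    # ... existing fields elided ...
    enable_granite_liger_kernel: bool = Field(
        default=False,
        description="Patch Granite 3.y models with Liger kernels for better throughput and memory headroom.",
    )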

@JamesKunstle (Contributor, Author)

++ agreed, will do once we've got the Granite numbers from PSAP

Labels: ci-failure, dependencies
2 participants