Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Performance] Sequential onloading #1263

Open
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

kylesayrs
Copy link
Collaborator

@kylesayrs kylesayrs commented Mar 18, 2025

Purpose

  • Reduce hardware requirements when calibrating large models by only onloading one layer at a time when calibrating using the sequential pipeline
  • Updating the examples can be done after pipeline extraction lands. Examples which only use the basic pipeline should dispatch to "auto", while examples which use GPTQ should dispatch to the cpu and set oneshot_device.

Usage

When using the sequential pipeline, a few behaviors change

  • If your model is dispatched to the gpu (has parameters which execute on a gpu), then a warning is raised
logger.warning(
    "Calibrating a model dispatched to the gpu can potentially lead to OOM "
    "errors. Consider loading the model without a `device_map` and instead "
    "executing with `cuda:0` (set `oneshot_device` to override this default)"
)
  • Otherwise (if you model is dispatched to the cpu), then the oneshot_device argument is used to determine the onload device (this defaults to cuda if a cuda device is available)
elif oneshot_device is None:
    has_cuda = torch.cuda.is_available()
    oneshot_device = torch.device("cuda:0") if has_cuda else torch.device("cpu")
    logger.info(f"No oneshot_device passed, using {oneshot_device}")

This policy encourages users to dispatch to the CPU when using the sequential pipeline, and to dispatch to "auto" when using the basic pipeline

Changes

  • Keep layer parameters onloaded during the entire sequential calibration + compression step

Testing

  • Calibrated and GPTQ-compressed one layer of Deepseek-V3 with a single H100 in 50 seconds
    • 4.5x Improvement over original 236 seconds
    • Peak memory of ~40 GB, which can be further reduced by increasing the granularity of sequential targets
  • Not offloading activations did not result in a performance improvement

Copy link

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@kylesayrs kylesayrs added the ready When a PR is ready for review label Mar 18, 2025
@kylesayrs kylesayrs self-assigned this Mar 18, 2025
@brian-dellabetta brian-dellabetta self-requested a review March 18, 2025 14:17


@contextlib.contextmanager
def align_modules(modules: Iterable[torch.nn.Module]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not keep this in compressed tensors with the other cpu offloading tools?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! Implementing here before the next CT release

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# and is only used for capturing outputs from the newly compressed modules
with HooksMixin.disable_hooks():
for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
with align_modules([layer]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like all we're doing is wrapping the forward passes in this context manager, if I'm reading this correctly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Rather than onloading then discarding for each of the 512 forward passes, we onload once for the layer and keep it onloaded through compression and propagation.

@@ -310,11 +313,13 @@ def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgrap
# save the subgraph for this partition
graph.lint()
input_names = set(node.name for node in graph.nodes if node.op == "placeholder")
modules = get_subgraph_modules(graph, parent_graph)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind explaining what we're changing in our graph partition here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The graph partition doesn't change, this change just collects all the modules used by this subgraph for use in onloading/offloading by the sequential pipeline.

Copy link
Collaborator

@brian-dellabetta brian-dellabetta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, i approved this thinking it was the one-liner removing clear-ml, will have to take a closer look

@brian-dellabetta brian-dellabetta dismissed their stale review March 18, 2025 14:20

sorry, i approved this thinking it was the one-liner removing clear-ml, will have to take a closer look

Copy link
Collaborator

@brian-dellabetta brian-dellabetta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am understanding this for the most part -- very cool!

dsikka and others added 10 commits March 27, 2025 13:52
SUMMARY:
- Remove requirement for tokens and the one test which uses them

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs kylesayrs force-pushed the kylesayrs/sequential-onloading branch from cf09876 to 72e7683 Compare March 27, 2025 17:53
Signed-off-by: Kyle Sayers <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready When a PR is ready for review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants