Add "selected_ops" transformer AC mode (#71)
epwalsh authored Oct 22, 2024
1 parent d90292e commit 425f7db
Showing 4 changed files with 54 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/actions/setup-venv/action.yml
@@ -11,7 +11,7 @@ inputs:
torch-version:
description: The PyTorch version to install
required: false
default: '==2.4.1'
default: '==2.5.0'
runs:
using: composite
steps:
@@ -44,7 +44,7 @@ runs:
# Set up virtual environment without cache hit.
test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv
. .venv/bin/activate
pip install 'torch${{ inputs.torch-version }}' --extra-index-url https://download.pytorch.org/whl/cpu
pip install 'torch${{ inputs.torch-version }}' --index-url https://download.pytorch.org/whl/cpu
pip install -e .[all]
- if: steps.virtualenv-cache.outputs.cache-hit == 'true'
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Trainer now starts with a dry-run of a fake batch created by `DataLoaderBase.get_mock_batch()`.
- Added `Callback.pre_backward()`, `.pre_eval_batch()`, and `.post_eval_batch()` methods.
- Added `Trainer.model_forward()`, `.get_losses()`, and `.eval_batch()` methods.
- Added a new `TransformerActivationCheckpointingMode`, "selected_ops" (requires torch 2.5 or newer).

### Changed

16 changes: 13 additions & 3 deletions src/olmo_core/nn/transformer/model.py
@@ -5,9 +5,6 @@
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from olmo_core.config import StrEnum
from olmo_core.data.utils import get_cumulative_document_lengths
@@ -16,6 +13,7 @@

from ..buffer_cache import BufferCache
from ..layer_norm import LayerNorm, LayerNormConfig
from ..utils import selective_checkpointing_context_fn
from .block import TransformerBlock, TransformerBlockConfig
from .init import InitMethod

@@ -56,6 +54,8 @@ class TransformerActivationCheckpointingMode(StrEnum):
    """Checkpoint only selected blocks."""
    selected_modules = "selected_modules"
    """Checkpoint only selected modules."""
    selected_ops = "selected_ops"
    """Checkpoint only a specific set of operations."""


class Transformer(nn.Module):
@@ -234,6 +234,10 @@ def apply_activation_checkpointing(
        :param modules: Required when :data:`mode` is "selected_modules". A list of module names
            to wrap for activation checkpointing. Globs are supported.
        """
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            checkpoint_wrapper as ptd_checkpoint_wrapper,
        )

        if (
            mode == TransformerActivationCheckpointingMode.selected_blocks
            and block_interval is None
@@ -270,6 +274,12 @@ def apply_activation_checkpointing(
                block = ptd_checkpoint_wrapper(block, preserve_rng_state=preserve_rng_state)
            elif mode == TransformerActivationCheckpointingMode.full:
                block = ptd_checkpoint_wrapper(block, preserve_rng_state=preserve_rng_state)
            elif mode == TransformerActivationCheckpointingMode.selected_ops:
                block = ptd_checkpoint_wrapper(
                    block,
                    context_fn=selective_checkpointing_context_fn,
                    preserve_rng_state=preserve_rng_state,
                )

            self.blocks.register_module(str(block_idx), block)

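A minimal usage sketch of the new mode, assuming an already-constructed `Transformer` instance and default values for the remaining arguments (the helper name below is illustrative, not part of the library):

from olmo_core.nn.transformer.model import (
    Transformer,
    TransformerActivationCheckpointingMode,
)


def enable_selected_ops_ac(model: Transformer) -> None:
    # Op-level checkpointing: each block is wrapped so that only the outputs of
    # ops in the save list (matmuls on every other call, SDPA, reduce-scatter)
    # are kept for backward; other activations are recomputed. Requires torch 2.5+.
    model.apply_activation_checkpointing(TransformerActivationCheckpointingMode.selected_ops)
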
38 changes: 38 additions & 0 deletions src/olmo_core/nn/utils.py
@@ -0,0 +1,38 @@
from collections import defaultdict
from typing import Dict

import torch


def _get_custom_checkpoint_policy(meta: Dict[str, int]):
    # Adapted from
    # https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
    from torch.utils.checkpoint import CheckpointPolicy

    _save_list = {
        torch.ops.aten.mm.default,  # type: ignore
        torch.ops.aten._scaled_dot_product_efficient_attention.default,  # type: ignore
        torch.ops.aten._scaled_dot_product_flash_attention.default,  # type: ignore
        torch.ops._c10d_functional.reduce_scatter_tensor.default,  # type: ignore
    }

    def _custom_policy(ctx, func, *args, **kwargs):
        del args, kwargs
        mode = "recompute" if ctx.is_recompute else "forward"
        mm_count_key = f"{mode}_mm_count"
        if func == torch.ops.aten.mm.default:  # type: ignore
            meta[mm_count_key] += 1
        # Save the output of ops in `_save_list`, except every second mm.
        to_save = func in _save_list and not (
            func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0  # type: ignore
        )
        return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy


def selective_checkpointing_context_fn():
    from torch.utils.checkpoint import create_selective_checkpoint_contexts

    meta: Dict[str, int] = defaultdict(int)
    return create_selective_checkpoint_contexts(_get_custom_checkpoint_policy(meta))
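
For reference, a self-contained sketch of how this context function plugs into PyTorch's non-reentrant checkpoint API; the toy function, shapes, and variable names below are illustrative and not part of the library:

import torch
from torch.utils.checkpoint import checkpoint

from olmo_core.nn.utils import selective_checkpointing_context_fn


def two_matmuls(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # Dispatches two aten.mm ops: under the policy above, the output of the first
    # mm is saved for backward while the second is recomputed.
    return torch.mm(torch.mm(x, w1).relu(), w2)


x = torch.randn(8, 64, requires_grad=True)
w1 = torch.randn(64, 256, requires_grad=True)
w2 = torch.randn(256, 64, requires_grad=True)

# `context_fn` is only honored on the non-reentrant (use_reentrant=False) path.
out = checkpoint(
    two_matmuls,
    x,
    w1,
    w2,
    use_reentrant=False,
    context_fn=selective_checkpointing_context_fn,
)
out.sum().backward()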
