Add support for MoE models #60

Merged — 61 commits merged into main from epwalsh/moe on Oct 21, 2024
Commits (61)
f5a96e0
add placeholder for MoE module
epwalsh Oct 2, 2024
936b6e8
start adding MoE components
epwalsh Oct 7, 2024
12f40a4
fix megablocks commit SHA
epwalsh Oct 8, 2024
346dc62
clean up
epwalsh Oct 8, 2024
4527d34
fix
epwalsh Oct 8, 2024
ee44f7d
updates
epwalsh Oct 9, 2024
f9b08bb
pin numpy
epwalsh Oct 9, 2024
3fe143c
updates
epwalsh Oct 9, 2024
4095321
Merge branch 'main' into epwalsh/moe
epwalsh Oct 10, 2024
217fa7e
pin to fork of megablocks
epwalsh Oct 10, 2024
31f6db1
fix docker build
epwalsh Oct 10, 2024
3128526
show image size
epwalsh Oct 10, 2024
8ecfdef
fix
epwalsh Oct 10, 2024
85bed9b
clear up some disk space
epwalsh Oct 10, 2024
0366d7c
fix
epwalsh Oct 10, 2024
47785d6
use larger runner
epwalsh Oct 10, 2024
43dab58
updates
epwalsh Oct 10, 2024
8f00ac8
Build separate dev image
epwalsh Oct 10, 2024
0d3e025
use dev image for testing
epwalsh Oct 10, 2024
a195ee5
fix other image builds
epwalsh Oct 10, 2024
f539821
revert nightly version
epwalsh Oct 10, 2024
ebc82ef
updates
epwalsh Oct 10, 2024
7a34742
Add MoE test
epwalsh Oct 10, 2024
8879118
Adjust tests
epwalsh Oct 10, 2024
4d1b5d8
fixes
epwalsh Oct 11, 2024
ba95e86
fix
epwalsh Oct 11, 2024
5fe437a
increase dims
epwalsh Oct 11, 2024
e7e75cb
comment
epwalsh Oct 11, 2024
d5bc650
Merge branch 'main' into epwalsh/moe
epwalsh Oct 11, 2024
c3176a4
add 9.0 arch list
epwalsh Oct 11, 2024
5a0432b
fix
epwalsh Oct 11, 2024
aefc85f
Use wrapper around megablocks classes
epwalsh Oct 11, 2024
f3b485d
Fixes, add callback to collect loss
epwalsh Oct 11, 2024
97a6ab5
fix test
epwalsh Oct 11, 2024
6e77e4b
add training script for OLMoE-1B-7B
epwalsh Oct 11, 2024
738eed6
fix init
epwalsh Oct 11, 2024
6105be1
fix
epwalsh Oct 11, 2024
881f6e1
another fix
epwalsh Oct 11, 2024
a5840de
decrease mbz
epwalsh Oct 11, 2024
baa0f89
fix?
epwalsh Oct 11, 2024
2ffa147
fix
epwalsh Oct 11, 2024
c8bd786
Fix Z-loss calculation
epwalsh Oct 12, 2024
8975115
Load MoE losses separately
epwalsh Oct 12, 2024
8d65d68
clean up API
epwalsh Oct 12, 2024
b8c84c4
fix
epwalsh Oct 12, 2024
912dc06
test different MoE MLP implementations
epwalsh Oct 12, 2024
12f13a2
Another fix to loss buffer reset
epwalsh Oct 12, 2024
68f2c79
Set test input dtype to BF16
epwalsh Oct 12, 2024
81541c6
Set dtype correctly
epwalsh Oct 12, 2024
f5706a9
improve how test dtype is logged
epwalsh Oct 12, 2024
046560f
specify when docker builds should run
epwalsh Oct 12, 2024
2394ba7
try compile again
epwalsh Oct 14, 2024
f58430f
improve image builds
epwalsh Oct 14, 2024
edcadb3
Add memory-optimized variant
epwalsh Oct 14, 2024
a1ba50e
Add fine-grained wrapping strategy
epwalsh Oct 14, 2024
a189795
fix
epwalsh Oct 14, 2024
b949489
fix example
epwalsh Oct 14, 2024
ac09e91
fix
epwalsh Oct 14, 2024
1af7c8e
Merge branch 'main' into epwalsh/moe
epwalsh Oct 17, 2024
f2ed966
Explicitly mark beta features
epwalsh Oct 18, 2024
f44a110
Update README.md
epwalsh Oct 18, 2024
Files changed
35 changes: 24 additions & 11 deletions .github/workflows/docker.yml
@@ -8,6 +8,12 @@ on:
pull_request:
branches:
- main
paths:
- 'Makefile'
- 'pyproject.toml'
- 'src/olmo_core/version.py'
- 'src/Dockerfile'
- '.github/workflows/docker.yml'
push:
branches:
- main
@@ -16,15 +22,11 @@ on:

jobs:
beaker:
name: Beaker image (${{ matrix.version }})
runs-on: ubuntu-latest
timeout-minutes: 20
name: Beaker images
runs-on: ubuntu-latest-m
timeout-minutes: 60
env:
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
strategy:
fail-fast: false
matrix:
version: [nightly, stable]
steps:
- uses: actions/checkout@v3

@@ -35,17 +37,28 @@ jobs:
run: |
echo "BEAKER_WORKSPACE=$(make get-beaker-workspace)" >> $GITHUB_ENV

- name: Build
- name: Build stable image
run: |
make ${{ matrix.version }}-image
make stable-image

- name: Build nightly image
run: |
make nightly-image

- uses: allenai/setup-beaker@v2
if: env.BEAKER_TOKEN != ''
with:
token: ${{ env.BEAKER_TOKEN }}
workspace: ${{ env.BEAKER_WORKSPACE }}

- name: Push
- name: Push stable image
if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/')
run: |
rm -rf /opt/hostedtoolcache # clear up some disk space
make beaker-image-stable

- name: Push nightly image
if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/')
run: |
make beaker-image-${{ matrix.version }}
rm -rf /opt/hostedtoolcache # clear up some disk space
make beaker-image-nightly
25 changes: 21 additions & 4 deletions .github/workflows/main.yml
@@ -109,10 +109,27 @@ jobs:
matrix:
task:
- name: Test (GPU)
run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/checkpoint*'
image: olmo-core
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
--ignore-glob='src/test/distributed/checkpoint*' \
--ignore-glob='src/test/nn/moe*' \
src/test/

- name: Test checkpoint (GPU)
run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint*
image: olmo-core
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
src/test/distributed/checkpoint*

- name: Test MoE (GPU)
image: olmo-core-nightly
gpus: 1
run: |
pytest -v --color=yes --durations=3 -m gpu \
src/test/nn/moe*
steps:
- uses: actions/checkout@v3

@@ -142,7 +159,7 @@ jobs:
- name: Get full image name
if: env.BEAKER_TOKEN != ''
run:
echo "BEAKER_IMAGE=$(make get-full-beaker-image-name)" >> $GITHUB_ENV
echo "BEAKER_IMAGE=$(make get-full-beaker-image-name IMAGE_NAME=${{ matrix.task.image }})" >> $GITHUB_ENV

- name: GPU Tests
uses: allenai/[email protected]
@@ -160,7 +177,7 @@ jobs:
priority: low
preemptible: true
resources:
gpuCount: 2
gpuCount: ${{ matrix.task.gpus }}
constraints:
cluster:
- ai2/allennlp-cirrascale
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,8 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Google Cloud support for `list_directory()` and `clear_directory()`.
- Added `CometCallback` for logging training runs to Comet.ml.
- Added `DataMixBase` class, to allow extending to new data mix groups.
- Added support for MoE-based models.
- Added method `DataLoaderBase.get_mock_batch()`.
- Trainer now starts with a dry-run of a fake batch created by `DataLoaderBase.get_mock_batch()`.
- Added `Callback.pre_backward()`, `.pre_eval_batch()`, and `.post_eval_batch()` methods.
- Added `Trainer.model_forward()`, `.get_losses()`, and `.eval_batch()` methods.

### Changed

27 changes: 19 additions & 8 deletions Makefile
@@ -1,8 +1,11 @@
BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11

# NOTE: when upgrading the nightly version you also need to upgrade the torch version specification
# in 'pyproject.toml' to include that nightly version.
NIGHTLY_VERSION = "2.6.0.dev20241009+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121"
TORCHAO_VERSION = "0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
TORCHAO_VERSION = "torchao==0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://[email protected]/epwalsh/megablocks.git@epwalsh/deps"
CUDA_TOOLKIT_VERSION = 12.1.0

VERSION = $(shell python src/olmo_core/version.py)
VERSION_SHORT = $(shell python src/olmo_core/version.py short)
@@ -45,25 +48,33 @@ stable-image :
docker build -f src/Dockerfile \
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(BASE_IMAGE) \
--build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--target stable \
--progress plain \
-t $(IMAGE_BASENAME) .

.PHONY : beaker-image-stable
beaker-image-stable : stable-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE)
echo "Built image '$(IMAGE_BASENAME)', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME) | numfmt --to=si)"

.PHONY : nightly-image
nightly-image :
docker build -f src/Dockerfile \
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(BASE_IMAGE) \
--build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \
--target nightly \
--progress plain \
-t $(IMAGE_BASENAME)-nightly .
echo "Built image '$(IMAGE_BASENAME)-nightly', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME)-nightly | numfmt --to=si)"

.PHONY : beaker-image-stable
beaker-image-stable : stable-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE)

.PHONY : beaker-image-nightly
beaker-image-nightly : nightly-image
@@ -77,4 +88,4 @@ get-beaker-workspace :

.PHONY : get-full-beaker-image-name
get-full-beaker-image-name :
@./src/scripts/beaker/get_full_image_name.sh $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
@./src/scripts/beaker/get_full_image_name.sh $(IMAGE_NAME) $(BEAKER_WORKSPACE)
6 changes: 6 additions & 0 deletions README.md
@@ -19,6 +19,12 @@ First install [PyTorch](https://pytorch.org) according to the instructions speci
pip install ai2-olmo-core
```

## API stability

Even though this library is under rapid development we are trying hard to adhere to [Semantic Versioning](https://semver.org/spec/v2.0.0.html) with every release except for features that are explicitly marked as beta features. Those features will be tagged like this in the [API docs](https://olmo-core.readthedocs.io/en/latest/):

![image](https://github.com/user-attachments/assets/c666686d-3ae6-4c88-8381-befd698d3fd0)

## Official training scripts

Official training scripts for various model sizes can be found in [`src/scripts/train/`](https://github.com/allenai/OLMo-core/tree/main/src/scripts/train).
6 changes: 6 additions & 0 deletions docs/source/nn/attention.rst
@@ -0,0 +1,6 @@
``nn.attention``
================

.. automodule:: olmo_core.nn.attention
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/feed_forward.rst
@@ -0,0 +1,6 @@
``nn.feed_forward``
===================

.. automodule:: olmo_core.nn.feed_forward
:members:
:member-order: bysource
34 changes: 5 additions & 29 deletions docs/source/nn/index.rst
@@ -3,38 +3,14 @@

.. automodule:: olmo_core.nn

Attention
---------

.. automodule:: olmo_core.nn.attention
:members:
:member-order: bysource

FeedForward
-----------

.. automodule:: olmo_core.nn.feed_forward
:members:
:member-order: bysource

RoPE
----

.. automodule:: olmo_core.nn.rope
:members:
:member-order: bysource

LayerNorms
----------

.. automodule:: olmo_core.nn.layer_norm
:members:
:member-order: bysource

.. toctree::
:maxdepth: 2
:caption: Submodules
:hidden:

attention
feed_forward
functional
layer_norm
moe
rope
transformer
6 changes: 6 additions & 0 deletions docs/source/nn/layer_norm.rst
@@ -0,0 +1,6 @@
``nn.layer_norm``
=================

.. automodule:: olmo_core.nn.layer_norm
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/moe.rst
@@ -0,0 +1,6 @@
``nn.moe``
==========

.. automodule:: olmo_core.nn.moe
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/rope.rst
@@ -0,0 +1,6 @@
``nn.rope``
===========

.. automodule:: olmo_core.nn.rope
:members:
:member-order: bysource
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ authors = [
requires-python = ">=3.9"
license = { file = "LICENSE" }
dependencies = [
"numpy",
"numpy<2.0",
"torch>=2.4,<=2.6.0.dev20241009",
"cached-path",
"requests",
38 changes: 35 additions & 3 deletions src/Dockerfile
@@ -1,12 +1,40 @@
# Base image comes with PyTorch, numpy, flash-attn
ARG BASE

#########################################################################
# Build image
#########################################################################

FROM ${BASE} as build

WORKDIR /app/build

# Install CUDA toolkit.
ARG CUDA_TOOLKIT_VERSION
RUN conda install -y -c nvidia cuda-toolkit==${CUDA_TOOLKIT_VERSION}

# Build megablocks and grouped-gemm.
ENV TORCH_CUDA_ARCH_LIST="8.0 9.0"
ENV GROUPED_GEMM_CUTLASS=1
ARG MEGABLOCKS_VERSION
RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" \
&& rm -rf torch-*.whl numpy-*.whl triton-*.whl

#########################################################################
# Stable image
#########################################################################

FROM ${BASE} as stable

# Install torchao
# Install torchao.
ARG TORCHAO_VERSION
RUN pip install --no-cache-dir torchao==${TORCHAO_VERSION}
RUN pip install --no-cache-dir ${TORCHAO_VERSION}

# Install other dependencies, but not the source code.
# Copy and install wheels from build image.
COPY --from=build /app/build /app/build
RUN pip install --no-cache-dir /app/build/*

# Install direct dependencies, but not source code.
COPY pyproject.toml .
COPY src/olmo_core/__init__.py src/olmo_core/__init__.py
COPY src/olmo_core/version.py src/olmo_core/version.py
@@ -16,6 +44,10 @@ RUN pip install --no-cache-dir '.[all]' && \

WORKDIR /app/olmo-core

#########################################################################
# Nightly image
#########################################################################

FROM stable as nightly

ARG NIGHTLY_VERSION
6 changes: 3 additions & 3 deletions src/examples/train.py
@@ -17,9 +17,9 @@
NumpyDatasetType,
TokenizerConfig,
)
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.distributed.utils import init_hybrid_shard_mesh
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
Duration,
@@ -58,7 +58,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
model_config = TransformerConfig.llama2_271M(
vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128
compile=True,
dp_config=DataParallelConfig(
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
)
9 changes: 9 additions & 0 deletions src/olmo_core/config.py
@@ -217,5 +217,14 @@ class DType(StrEnum):
float32 = "float32"
bfloat16 = "bfloat16"

@classmethod
def from_pt(cls, dtype: torch.dtype) -> "DType":
if dtype == torch.float32:
return DType.float32
elif dtype == torch.bfloat16:
return DType.bfloat16
else:
raise NotImplementedError(dtype)

def as_pt(self) -> torch.dtype:
return getattr(torch, self)
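
The new `DType.from_pt()` classmethod above converts a `torch.dtype` back into the config enum, complementing the existing `as_pt()`. A minimal round-trip sketch (not part of this PR; it only assumes the behavior shown in the diff):

```python
import torch

from olmo_core.config import DType

# torch dtype -> config enum
assert DType.from_pt(torch.bfloat16) is DType.bfloat16

# config enum -> torch dtype (existing behavior: getattr(torch, "bfloat16"))
assert DType.bfloat16.as_pt() is torch.bfloat16

# dtypes not handled by the new classmethod raise NotImplementedError
try:
    DType.from_pt(torch.float16)
except NotImplementedError:
    pass
```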
22 changes: 22 additions & 0 deletions src/olmo_core/doc_utils.py
@@ -0,0 +1,22 @@
from typing import TypeVar

T = TypeVar("T")


def beta_feature(f: T) -> T:
"""
Mark a class or function as a beta feature.
"""
if f.__doc__ is None:
f.__doc__ = ""

f.__doc__ += """

.. warning::
This is a beta feature! The API is subject to change even with minor and patch releases.
If you choose to use this feature please read the `CHANGELOG <https://github.com/allenai/OLMo-core/blob/main/CHANGELOG.md>`_
before upgrading your version of this library.

"""

return f
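
The `beta_feature` decorator above simply appends a Sphinx `.. warning::` block to the wrapped object's docstring so it renders in the API docs. A small usage sketch (the class here is hypothetical, for illustration only):

```python
from olmo_core.doc_utils import beta_feature


@beta_feature
class ExperimentalThing:
    """An example class; the appended warning shows up in the rendered API docs."""


# The decorator mutates __doc__ in place and returns the same object.
assert ".. warning::" in ExperimentalThing.__doc__
```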