Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MoE models #60

Merged
merged 61 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
f5a96e0
add placeholder for MoE module
epwalsh Oct 2, 2024
936b6e8
start adding MoE components
epwalsh Oct 7, 2024
12f40a4
fix megablocks commit SHA
epwalsh Oct 8, 2024
346dc62
clean up
epwalsh Oct 8, 2024
4527d34
fix
epwalsh Oct 8, 2024
ee44f7d
updates
epwalsh Oct 9, 2024
f9b08bb
pin numpy
epwalsh Oct 9, 2024
3fe143c
updates
epwalsh Oct 9, 2024
4095321
Merge branch 'main' into epwalsh/moe
epwalsh Oct 10, 2024
217fa7e
pin to fork of megablocks
epwalsh Oct 10, 2024
31f6db1
fix docker build
epwalsh Oct 10, 2024
3128526
show image size
epwalsh Oct 10, 2024
8ecfdef
fix
epwalsh Oct 10, 2024
85bed9b
clear up some disk space
epwalsh Oct 10, 2024
0366d7c
fix
epwalsh Oct 10, 2024
47785d6
use larger runner
epwalsh Oct 10, 2024
43dab58
updates
epwalsh Oct 10, 2024
8f00ac8
Build separate dev image
epwalsh Oct 10, 2024
0d3e025
use dev image for testing
epwalsh Oct 10, 2024
a195ee5
fix other image builds
epwalsh Oct 10, 2024
f539821
revert nightly version
epwalsh Oct 10, 2024
ebc82ef
updates
epwalsh Oct 10, 2024
7a34742
Add MoE test
epwalsh Oct 10, 2024
8879118
Adjust tests
epwalsh Oct 10, 2024
4d1b5d8
fixes
epwalsh Oct 11, 2024
ba95e86
fix
epwalsh Oct 11, 2024
5fe437a
increase dims
epwalsh Oct 11, 2024
e7e75cb
comment
epwalsh Oct 11, 2024
d5bc650
Merge branch 'main' into epwalsh/moe
epwalsh Oct 11, 2024
c3176a4
add 9.0 arch list
epwalsh Oct 11, 2024
5a0432b
fix
epwalsh Oct 11, 2024
aefc85f
Use wrapper around megablocks classes
epwalsh Oct 11, 2024
f3b485d
Fixes, add callback to collect loss
epwalsh Oct 11, 2024
97a6ab5
fix test
epwalsh Oct 11, 2024
6e77e4b
add training script for OLMoE-1B-7B
epwalsh Oct 11, 2024
738eed6
fix init
epwalsh Oct 11, 2024
6105be1
fix
epwalsh Oct 11, 2024
881f6e1
another fix
epwalsh Oct 11, 2024
a5840de
decrease mbz
epwalsh Oct 11, 2024
baa0f89
fix?
epwalsh Oct 11, 2024
2ffa147
fix
epwalsh Oct 11, 2024
c8bd786
Fix Z-loss calculation
epwalsh Oct 12, 2024
8975115
Load MoE losses separately
epwalsh Oct 12, 2024
8d65d68
clean up API
epwalsh Oct 12, 2024
b8c84c4
fix
epwalsh Oct 12, 2024
912dc06
test different MoE MLP implementations
epwalsh Oct 12, 2024
12f13a2
Another fix to loss buffer reset
epwalsh Oct 12, 2024
68f2c79
Set test input dtype to BF16
epwalsh Oct 12, 2024
81541c6
Set dtype correctly
epwalsh Oct 12, 2024
f5706a9
improve how test dtype is logged
epwalsh Oct 12, 2024
046560f
specify when docker builds should run
epwalsh Oct 12, 2024
2394ba7
try compile again
epwalsh Oct 14, 2024
f58430f
improve image builds
epwalsh Oct 14, 2024
edcadb3
Add memory-optimized variant
epwalsh Oct 14, 2024
a1ba50e
Add fine-grained wrapping strategy
epwalsh Oct 14, 2024
a189795
fix
epwalsh Oct 14, 2024
b949489
fix example
epwalsh Oct 14, 2024
ac09e91
fix
epwalsh Oct 14, 2024
1af7c8e
Merge branch 'main' into epwalsh/moe
epwalsh Oct 17, 2024
f2ed966
Explicitly mark beta features
epwalsh Oct 18, 2024
f44a110
Update README.md
epwalsh Oct 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,23 @@ on:

jobs:
beaker:
name: Beaker image (${{ matrix.version }})
runs-on: ubuntu-latest
timeout-minutes: 20
name: Beaker image (${{ matrix.build.version }})
runs-on: ${{ matrix.build.runs-on }}
timeout-minutes: 60
env:
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
strategy:
fail-fast: false
matrix:
version: [nightly, stable]
build:
- version: stable
runs-on: ubuntu-latest

- version: nightly
runs-on: ubuntu-latest

- version: dev
runs-on: ubuntu-latest-m # requires a larger instance
steps:
- uses: actions/checkout@v3

Expand All @@ -37,7 +45,7 @@ jobs:

- name: Build
run: |
make ${{ matrix.version }}-image
make ${{ matrix.build.version }}-image

- uses: allenai/setup-beaker@v2
if: env.BEAKER_TOKEN != ''
Expand All @@ -48,4 +56,5 @@ jobs:
- name: Push
if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/')
run: |
make beaker-image-${{ matrix.version }}
rm -rf /opt/hostedtoolcache # clear up some disk space
make beaker-image-${{ matrix.build.version }}
25 changes: 21 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,27 @@ jobs:
matrix:
task:
- name: Test (GPU)
run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/checkpoint*'
image: olmo-core
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
--ignore-glob='src/test/distributed/checkpoint*' \
--ignore-glob='src/test/nn/moe*' \
src/test/

- name: Test checkpoint (GPU)
run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint*
image: olmo-core
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
src/test/distributed/checkpoint*

- name: Test MoE (GPU)
image: olmo-core-dev
gpus: 1
run: |
pytest -v --color=yes --durations=3 -m gpu \
src/test/nn/moe*
steps:
- uses: actions/checkout@v3

Expand Down Expand Up @@ -142,7 +159,7 @@ jobs:
- name: Get full image name
if: env.BEAKER_TOKEN != ''
run:
echo "BEAKER_IMAGE=$(make get-full-beaker-image-name)" >> $GITHUB_ENV
echo "BEAKER_IMAGE=$(make get-full-beaker-image-name IMAGE_NAME=${{ matrix.task.image }})" >> $GITHUB_ENV

- name: GPU Tests
uses: allenai/[email protected]
Expand All @@ -160,7 +177,7 @@ jobs:
priority: low
preemptible: true
resources:
gpuCount: 2
gpuCount: ${{ matrix.task.gpus }}
constraints:
cluster:
- ai2/allennlp-cirrascale
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `CometCallback` for logging training runs to Comet.ml.
- Added `DataMixBase` class, to allow extending to new data mix groups.
- Added support for MoE-based models.
- Added callback method `Callback.post_model_forward()`.
- Added method `DataLoaderBase.get_mock_batch()`.
- Trainer now starts with a dry-run of a fake batch created by `DataLoaderBase.get_mock_batch()`.

Expand Down
43 changes: 35 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11
DEV_BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11-dev

# NOTE: when upgrading the nightly version you also need to upgrade the torch version specification
# in 'pyproject.toml' to include that nightly version.
NIGHTLY_VERSION = "2.6.0.dev20241009+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121"
TORCHAO_VERSION = "0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
TORCHAO_VERSION = "torchao==0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://[email protected]/epwalsh/megablocks.git@epwalsh/deps"

VERSION = $(shell python src/olmo_core/version.py)
VERSION_SHORT = $(shell python src/olmo_core/version.py short)
Expand Down Expand Up @@ -46,14 +49,11 @@ stable-image :
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(BASE_IMAGE) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--target stable \
--progress plain \
-t $(IMAGE_BASENAME) .

.PHONY : beaker-image-stable
beaker-image-stable : stable-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE)
echo "Built image '$(IMAGE_BASENAME)', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME) | numfmt --to=si)"

.PHONY : nightly-image
nightly-image :
Expand All @@ -63,18 +63,45 @@ nightly-image :
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \
--target nightly \
--progress plain \
-t $(IMAGE_BASENAME)-nightly .
echo "Built image '$(IMAGE_BASENAME)-nightly', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME)-nightly | numfmt --to=si)"

.PHONY : dev-image
dev-image :
docker build -f src/Dockerfile \
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(DEV_BASE_IMAGE) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \
--target dev \
--progress plain \
-t $(IMAGE_BASENAME)-dev .
echo "Built image '$(IMAGE_BASENAME)-dev', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME)-nightly | numfmt --to=si)"

.PHONY : beaker-image-stable
beaker-image-stable : stable-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE)

.PHONY : beaker-image-nightly
beaker-image-nightly : nightly-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME)-nightly $(IMAGE_BASENAME)-nightly $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME)-nightly $(IMAGE_BASENAME)-v$(VERSION_SHORT)-nightly $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME)-nightly $(IMAGE_BASENAME)-v$(VERSION)-nightly $(BEAKER_WORKSPACE)

.PHONY : beaker-image-dev
beaker-image-dev : dev-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME)-dev $(IMAGE_BASENAME)-dev $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME)-dev $(IMAGE_BASENAME)-v$(VERSION_SHORT)-dev $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME)-dev $(IMAGE_BASENAME)-v$(VERSION)-dev $(BEAKER_WORKSPACE)

.PHONY : get-beaker-workspace
get-beaker-workspace :
@echo $(BEAKER_WORKSPACE)

.PHONY : get-full-beaker-image-name
get-full-beaker-image-name :
@./src/scripts/beaker/get_full_image_name.sh $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
@./src/scripts/beaker/get_full_image_name.sh $(IMAGE_NAME) $(BEAKER_WORKSPACE)
6 changes: 6 additions & 0 deletions docs/source/nn/attention.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
``nn.attention``
================

.. automodule:: olmo_core.nn.attention
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/feed_forward.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
``nn.feed_forward``
===================

.. automodule:: olmo_core.nn.feed_forward
:members:
:member-order: bysource
34 changes: 5 additions & 29 deletions docs/source/nn/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,14 @@

.. automodule:: olmo_core.nn

Attention
---------

.. automodule:: olmo_core.nn.attention
:members:
:member-order: bysource

FeedForward
-----------

.. automodule:: olmo_core.nn.feed_forward
:members:
:member-order: bysource

RoPE
----

.. automodule:: olmo_core.nn.rope
:members:
:member-order: bysource

LayerNorms
----------

.. automodule:: olmo_core.nn.layer_norm
:members:
:member-order: bysource

.. toctree::
:maxdepth: 2
:caption: Submodules
:hidden:

attention
feed_forward
functional
layer_norm
moe
rope
transformer
6 changes: 6 additions & 0 deletions docs/source/nn/layer_norm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
``nn.layer_norm``
=================

.. automodule:: olmo_core.nn.layer_norm
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/moe.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
``nn.moe``
==========

.. automodule:: olmo_core.nn.moe
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/rope.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
``nn.rope``
===========

.. automodule:: olmo_core.nn.rope
:members:
:member-order: bysource
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ authors = [
requires-python = ">=3.9"
license = { file = "LICENSE" }
dependencies = [
"numpy",
"numpy<2.0",
"torch>=2.4,<=2.6.0.dev20241009",
"cached-path",
"requests",
Expand Down
26 changes: 23 additions & 3 deletions src/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
#########################################################################
# Stable image
#########################################################################

# Base image comes with PyTorch, numpy, flash-attn
ARG BASE
FROM ${BASE} as stable

# Install torchao
# Install torchao.
ARG TORCHAO_VERSION
RUN pip install --no-cache-dir torchao==${TORCHAO_VERSION}
RUN pip install --no-cache-dir ${TORCHAO_VERSION}

# Install other dependencies, but not the source code.
# Install direct dependencies, but not source code.
COPY pyproject.toml .
COPY src/olmo_core/__init__.py src/olmo_core/__init__.py
COPY src/olmo_core/version.py src/olmo_core/version.py
Expand All @@ -16,7 +20,23 @@ RUN pip install --no-cache-dir '.[all]' && \

WORKDIR /app/olmo-core

#########################################################################
# Nightly image
#########################################################################

FROM stable as nightly

ARG NIGHTLY_VERSION
RUN pip install --no-cache-dir --pre torch==${NIGHTLY_VERSION}

#########################################################################
# Dev image
#########################################################################

FROM nightly as dev

# Install core dev dependencies.
ENV TORCH_CUDA_ARCH_LIST="8.0 9.0"
ENV GROUPED_GEMM_CUTLASS=1
ARG MEGABLOCKS_VERSION
RUN pip install --no-cache-dir "${MEGABLOCKS_VERSION}"
12 changes: 10 additions & 2 deletions src/olmo_core/internal/experiment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import sys
from dataclasses import dataclass
from typing import Callable, Dict, List, cast
from typing import Callable, Dict, List, Optional, cast

from beaker import Beaker

Expand Down Expand Up @@ -232,6 +232,7 @@ def build_config(
model_config_builder: Callable[[CommonComponents], TransformerConfig],
optim_config_builder: Callable[[CommonComponents], AdamWConfig],
trainer_config_builder: Callable[[CommonComponents], TrainerConfig],
finalize_config: Optional[Callable[[ExperimentConfig], None]] = None,
) -> ExperimentConfig:
common = build_common_components(
script, cmd, run_name, cluster, overrides, global_batch_size=global_batch_size
Expand All @@ -253,7 +254,12 @@ def build_config(
dataset=common.dataset,
data_loader=common.data_loader,
trainer=trainer,
).merge(overrides)
)

if finalize_config is not None:
finalize_config(config)

config = config.merge(overrides)

if config.model.float8_config is not None and config.model.float8_config.enabled:
config.trainer.add_callback(
Expand Down Expand Up @@ -313,6 +319,7 @@ def main(
model_config_builder: Callable[[CommonComponents], TransformerConfig],
optim_config_builder: Callable[[CommonComponents], AdamWConfig],
trainer_config_builder: Callable[[CommonComponents], TrainerConfig],
finalize_config: Optional[Callable[[ExperimentConfig], None]] = None,
):
usage = f"""
[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]RUN_NAME CLUSTER[/] [i][OVERRIDES...][/]
Expand Down Expand Up @@ -350,6 +357,7 @@ def main(
model_config_builder=model_config_builder,
optim_config_builder=optim_config_builder,
trainer_config_builder=trainer_config_builder,
finalize_config=finalize_config,
)

cmd.run(config)
5 changes: 5 additions & 0 deletions src/olmo_core/launch/beaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class OLMoCoreBeakerImage(StrEnum):
Built with the latest compatible nightly version of PyTorch.
"""

dev = "olmo-core-dev"
"""
Like :data:`nightly` but includes experimental dependencies and the CUDA toolkit.
"""


@dataclass
class BeakerEnvVar(Config):
Expand Down
8 changes: 8 additions & 0 deletions src/olmo_core/nn/moe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
MoE layers. Requires `megablocks <https://github.com/databricks/megablocks>`_.
"""

from .config import MoEActivationFn, MoEConfig, MoEMLPImplementation, MoEType
from .layers import MoE

__all__ = ["MoE", "MoEConfig", "MoEType", "MoEActivationFn", "MoEMLPImplementation"]
Loading
Loading