diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 1a6aeb57..02a154bd 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -8,6 +8,12 @@ on: pull_request: branches: - main + paths: + - 'Makefile' + - 'pyproject.toml' + - 'src/olmo_core/version.py' + - 'src/Dockerfile' + - '.github/workflows/docker.yml' push: branches: - main @@ -16,15 +22,11 @@ on: jobs: beaker: - name: Beaker image (${{ matrix.version }}) - runs-on: ubuntu-latest - timeout-minutes: 20 + name: Beaker images + runs-on: ubuntu-latest-m + timeout-minutes: 60 env: BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} - strategy: - fail-fast: false - matrix: - version: [nightly, stable] steps: - uses: actions/checkout@v3 @@ -35,9 +37,13 @@ jobs: run: | echo "BEAKER_WORKSPACE=$(make get-beaker-workspace)" >> $GITHUB_ENV - - name: Build + - name: Build stable image run: | - make ${{ matrix.version }}-image + make stable-image + + - name: Build nightly image + run: | + make nightly-image - uses: allenai/setup-beaker@v2 if: env.BEAKER_TOKEN != '' @@ -45,7 +51,14 @@ jobs: token: ${{ env.BEAKER_TOKEN }} workspace: ${{ env.BEAKER_WORKSPACE }} - - name: Push + - name: Push stable image + if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/') + run: | + rm -rf /opt/hostedtoolcache # clear up some disk space + make beaker-image-stable + + - name: Push nightly image if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/') run: | - make beaker-image-${{ matrix.version }} + rm -rf /opt/hostedtoolcache # clear up some disk space + make beaker-image-nightly diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 59595b8d..cba7adb2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -109,10 +109,27 @@ jobs: matrix: task: - name: Test (GPU) - run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/checkpoint*' + image: olmo-core + gpus: 2 + run: | + pytest -v --color=yes --durations=3 -m gpu \ + --ignore-glob='src/test/distributed/checkpoint*' \ + --ignore-glob='src/test/nn/moe*' \ + src/test/ - name: Test checkpoint (GPU) - run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint* + image: olmo-core + gpus: 2 + run: | + pytest -v --color=yes --durations=3 -m gpu \ + src/test/distributed/checkpoint* + + - name: Test MoE (GPU) + image: olmo-core-nightly + gpus: 1 + run: | + pytest -v --color=yes --durations=3 -m gpu \ + src/test/nn/moe* steps: - uses: actions/checkout@v3 @@ -142,7 +159,7 @@ jobs: - name: Get full image name if: env.BEAKER_TOKEN != '' run: - echo "BEAKER_IMAGE=$(make get-full-beaker-image-name)" >> $GITHUB_ENV + echo "BEAKER_IMAGE=$(make get-full-beaker-image-name IMAGE_NAME=${{ matrix.task.image }})" >> $GITHUB_ENV - name: GPU Tests uses: allenai/beaker-run-action@v1.2 @@ -160,7 +177,7 @@ jobs: priority: low preemptible: true resources: - gpuCount: 2 + gpuCount: ${{ matrix.task.gpus }} constraints: cluster: - ai2/allennlp-cirrascale diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cf08629..b85e3a20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Google Cloud support for `list_directory()` and `clear_directory()`. - Added `CometCallback` for logging training runs to Comet.ml. - Added `DataMixBase` class, to allow extending to new data mix groups. +- Added support for MoE-based models. - Added method `DataLoaderBase.get_mock_batch()`. 
- Trainer now starts with a dry-run of a fake batch created by `DataLoaderBase.get_mock_batch()`. +- Added `Callback.pre_backward()`, `.pre_eval_batch()`, and `.post_eval_batch()` methods. +- Added `Trainer.model_forward()`, `.get_losses()`, and `.eval_batch()` methods. ### Changed diff --git a/Makefile b/Makefile index 13bb794c..d5cda160 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,11 @@ BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11 + # NOTE: when upgrading the nightly version you also need to upgrade the torch version specification # in 'pyproject.toml' to include that nightly version. NIGHTLY_VERSION = "2.6.0.dev20241009+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121" -TORCHAO_VERSION = "0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121" +TORCHAO_VERSION = "torchao==0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121" +MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" +CUDA_TOOLKIT_VERSION = 12.1.0 VERSION = $(shell python src/olmo_core/version.py) VERSION_SHORT = $(shell python src/olmo_core/version.py short) @@ -45,25 +48,33 @@ stable-image : docker build -f src/Dockerfile \ --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg BASE=$(BASE_IMAGE) \ + --build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \ + --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --target stable \ + --progress plain \ -t $(IMAGE_BASENAME) . - -.PHONY : beaker-image-stable -beaker-image-stable : stable-image - ./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE) - ./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE) - ./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE) + echo "Built image '$(IMAGE_BASENAME)', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME) | numfmt --to=si)" .PHONY : nightly-image nightly-image : docker build -f src/Dockerfile \ --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg BASE=$(BASE_IMAGE) \ + --build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \ + --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \ --target nightly \ + --progress plain \ -t $(IMAGE_BASENAME)-nightly . 
+ echo "Built image '$(IMAGE_BASENAME)-nightly', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME)-nightly | numfmt --to=si)" + +.PHONY : beaker-image-stable +beaker-image-stable : stable-image + ./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE) + ./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE) + ./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE) .PHONY : beaker-image-nightly beaker-image-nightly : nightly-image @@ -77,4 +88,4 @@ get-beaker-workspace : .PHONY : get-full-beaker-image-name get-full-beaker-image-name : - @./src/scripts/beaker/get_full_image_name.sh $(IMAGE_BASENAME) $(BEAKER_WORKSPACE) + @./src/scripts/beaker/get_full_image_name.sh $(IMAGE_NAME) $(BEAKER_WORKSPACE) diff --git a/README.md b/README.md index ff70c9fd..995f7636 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,12 @@ First install [PyTorch](https://pytorch.org) according to the instructions speci pip install ai2-olmo-core ``` +## API stability + +Even though this library is under rapid development we are trying hard to adhere to [Semantic Versioning](https://semver.org/spec/v2.0.0.html) with every release except for features that are explicitly marked as beta features. Those features will be tagged like this in the [API docs](https://olmo-core.readthedocs.io/en/latest/): + +![image](https://github.com/user-attachments/assets/c666686d-3ae6-4c88-8381-befd698d3fd0) + ## Official training scripts Official training scripts for various model sizes can be found in [`src/scripts/train/`](https://github.com/allenai/OLMo-core/tree/main/src/scripts/train). diff --git a/docs/source/nn/attention.rst b/docs/source/nn/attention.rst new file mode 100644 index 00000000..1a806668 --- /dev/null +++ b/docs/source/nn/attention.rst @@ -0,0 +1,6 @@ +``nn.attention`` +================ + +.. automodule:: olmo_core.nn.attention + :members: + :member-order: bysource diff --git a/docs/source/nn/feed_forward.rst b/docs/source/nn/feed_forward.rst new file mode 100644 index 00000000..eee1ad53 --- /dev/null +++ b/docs/source/nn/feed_forward.rst @@ -0,0 +1,6 @@ +``nn.feed_forward`` +=================== + +.. automodule:: olmo_core.nn.feed_forward + :members: + :member-order: bysource diff --git a/docs/source/nn/index.rst b/docs/source/nn/index.rst index 9f84f3f1..6f23dfc8 100644 --- a/docs/source/nn/index.rst +++ b/docs/source/nn/index.rst @@ -3,38 +3,14 @@ .. automodule:: olmo_core.nn -Attention ---------- - -.. automodule:: olmo_core.nn.attention - :members: - :member-order: bysource - -FeedForward ------------ - -.. automodule:: olmo_core.nn.feed_forward - :members: - :member-order: bysource - -RoPE ----- - -.. automodule:: olmo_core.nn.rope - :members: - :member-order: bysource - -LayerNorms ----------- - -.. automodule:: olmo_core.nn.layer_norm - :members: - :member-order: bysource - .. toctree:: :maxdepth: 2 :caption: Submodules - :hidden: + attention + feed_forward functional + layer_norm + moe + rope transformer diff --git a/docs/source/nn/layer_norm.rst b/docs/source/nn/layer_norm.rst new file mode 100644 index 00000000..ed4371bd --- /dev/null +++ b/docs/source/nn/layer_norm.rst @@ -0,0 +1,6 @@ +``nn.layer_norm`` +================= + +.. 
automodule:: olmo_core.nn.layer_norm + :members: + :member-order: bysource diff --git a/docs/source/nn/moe.rst b/docs/source/nn/moe.rst new file mode 100644 index 00000000..cd404a9a --- /dev/null +++ b/docs/source/nn/moe.rst @@ -0,0 +1,6 @@ +``nn.moe`` +========== + +.. automodule:: olmo_core.nn.moe + :members: + :member-order: bysource diff --git a/docs/source/nn/rope.rst b/docs/source/nn/rope.rst new file mode 100644 index 00000000..3224400c --- /dev/null +++ b/docs/source/nn/rope.rst @@ -0,0 +1,6 @@ +``nn.rope`` +=========== + +.. automodule:: olmo_core.nn.rope + :members: + :member-order: bysource diff --git a/pyproject.toml b/pyproject.toml index 95ec06a8..5af9c589 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ requires-python = ">=3.9" license = { file = "LICENSE" } dependencies = [ - "numpy", + "numpy<2.0", "torch>=2.4,<=2.6.0.dev20241009", "cached-path", "requests", diff --git a/src/Dockerfile b/src/Dockerfile index 38e4072c..1a7255a5 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -1,12 +1,40 @@ # Base image comes with PyTorch, numpy, flash-attn ARG BASE + +######################################################################### +# Build image +######################################################################### + +FROM ${BASE} as build + +WORKDIR /app/build + +# Install CUDA toolkit. +ARG CUDA_TOOLKIT_VERSION +RUN conda install -y -c nvidia cuda-toolkit==${CUDA_TOOLKIT_VERSION} + +# Build megablocks and grouped-gemm. +ENV TORCH_CUDA_ARCH_LIST="8.0 9.0" +ENV GROUPED_GEMM_CUTLASS=1 +ARG MEGABLOCKS_VERSION +RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" \ + && rm -rf torch-*.whl numpy-*.whl triton-*.whl + +######################################################################### +# Stable image +######################################################################### + FROM ${BASE} as stable -# Install torchao +# Install torchao. ARG TORCHAO_VERSION -RUN pip install --no-cache-dir torchao==${TORCHAO_VERSION} +RUN pip install --no-cache-dir ${TORCHAO_VERSION} -# Install other dependencies, but not the source code. +# Copy and install wheels from build image. +COPY --from=build /app/build /app/build +RUN pip install --no-cache-dir /app/build/* + +# Install direct dependencies, but not source code. COPY pyproject.toml . 
COPY src/olmo_core/__init__.py src/olmo_core/__init__.py COPY src/olmo_core/version.py src/olmo_core/version.py @@ -16,6 +44,10 @@ RUN pip install --no-cache-dir '.[all]' && \ WORKDIR /app/olmo-core +######################################################################### +# Nightly image +######################################################################### + FROM stable as nightly ARG NIGHTLY_VERSION diff --git a/src/examples/train.py b/src/examples/train.py index 727b5756..deec8307 100644 --- a/src/examples/train.py +++ b/src/examples/train.py @@ -17,9 +17,9 @@ NumpyDatasetType, TokenizerConfig, ) -from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType +from olmo_core.distributed.parallel import DataParallelType from olmo_core.distributed.utils import init_hybrid_shard_mesh -from olmo_core.nn.transformer import TransformerConfig +from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride from olmo_core.train import ( Duration, @@ -58,7 +58,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: model_config = TransformerConfig.llama2_271M( vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128 compile=True, - dp_config=DataParallelConfig( + dp_config=TransformerDataParallelConfig( name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 ), ) diff --git a/src/olmo_core/config.py b/src/olmo_core/config.py index 3e02b556..27f509a5 100644 --- a/src/olmo_core/config.py +++ b/src/olmo_core/config.py @@ -217,5 +217,14 @@ class DType(StrEnum): float32 = "float32" bfloat16 = "bfloat16" + @classmethod + def from_pt(cls, dtype: torch.dtype) -> "DType": + if dtype == torch.float32: + return DType.float32 + elif dtype == torch.bfloat16: + return DType.bfloat16 + else: + raise NotImplementedError(dtype) + def as_pt(self) -> torch.dtype: return getattr(torch, self) diff --git a/src/olmo_core/doc_utils.py b/src/olmo_core/doc_utils.py new file mode 100644 index 00000000..00bc4551 --- /dev/null +++ b/src/olmo_core/doc_utils.py @@ -0,0 +1,22 @@ +from typing import TypeVar + +T = TypeVar("T") + + +def beta_feature(f: T) -> T: + """ + Mark a class or function as a beta feature. + """ + if f.__doc__ is None: + f.__doc__ = "" + + f.__doc__ += """ + + .. warning:: + This is a beta feature! The API is subject to change even with minor and patch releases. + If you choose to use this feature please read the `CHANGELOG `_ + before upgrading your version of this library. 
+ + """ + + return f diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 42d0ca1e..efe15a4f 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -1,7 +1,7 @@ import logging import sys from dataclasses import dataclass -from typing import Callable, Dict, List, cast +from typing import Callable, Dict, List, Optional, cast from beaker import Beaker @@ -232,6 +232,7 @@ def build_config( model_config_builder: Callable[[CommonComponents], TransformerConfig], optim_config_builder: Callable[[CommonComponents], AdamWConfig], trainer_config_builder: Callable[[CommonComponents], TrainerConfig], + finalize_config: Optional[Callable[[ExperimentConfig], None]] = None, ) -> ExperimentConfig: common = build_common_components( script, cmd, run_name, cluster, overrides, global_batch_size=global_batch_size @@ -253,7 +254,12 @@ def build_config( dataset=common.dataset, data_loader=common.data_loader, trainer=trainer, - ).merge(overrides) + ) + + if finalize_config is not None: + finalize_config(config) + + config = config.merge(overrides) if config.model.float8_config is not None and config.model.float8_config.enabled: config.trainer.add_callback( @@ -313,6 +319,7 @@ def main( model_config_builder: Callable[[CommonComponents], TransformerConfig], optim_config_builder: Callable[[CommonComponents], AdamWConfig], trainer_config_builder: Callable[[CommonComponents], TrainerConfig], + finalize_config: Optional[Callable[[ExperimentConfig], None]] = None, ): usage = f""" [yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]RUN_NAME CLUSTER[/] [i][OVERRIDES...][/] @@ -350,6 +357,7 @@ def main( model_config_builder=model_config_builder, optim_config_builder=optim_config_builder, trainer_config_builder=trainer_config_builder, + finalize_config=finalize_config, ) cmd.run(config) diff --git a/src/olmo_core/nn/moe/__init__.py b/src/olmo_core/nn/moe/__init__.py new file mode 100644 index 00000000..742ca33d --- /dev/null +++ b/src/olmo_core/nn/moe/__init__.py @@ -0,0 +1,8 @@ +""" +MoE layers. Requires `megablocks `_. +""" + +from .config import MoEActivationFn, MoEConfig, MoEMLPImplementation, MoEType +from .layers import MoE + +__all__ = ["MoE", "MoEConfig", "MoEType", "MoEActivationFn", "MoEMLPImplementation"] diff --git a/src/olmo_core/nn/moe/config.py b/src/olmo_core/nn/moe/config.py new file mode 100644 index 00000000..cbde42fb --- /dev/null +++ b/src/olmo_core/nn/moe/config.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass +from functools import partial +from typing import Callable + +import torch +import torch.nn.functional as F + +from olmo_core.config import Config, DType, StrEnum +from olmo_core.doc_utils import beta_feature + +from .layers import MoE as MoEWrapper + + +class MoEType(StrEnum): + """ + An enumeration of MoE layer types. + """ + + default = "default" + """ + The default version. + """ + + dropless = "dropless" + """ + The `dropless + `_ version. + """ + + +class MoEActivationFn(StrEnum): + """ + An enumeration of the different MoE activation functions available. + """ + + swiglu = "swiglu" + """ + SwiGLU. + """ + gelu = "gelu" + """ + GeLU. + """ + gelu_tanh = "gelu_tanh" + """ + GeLU with tanh approximation. + """ + relu = "relu" + """ + ReLU. 
+ """ + + def build(self) -> Callable[[torch.Tensor], torch.Tensor]: + if self == MoEActivationFn.swiglu: + return partial(F.silu, inplace=False) + elif self == MoEActivationFn.gelu: + return partial(F.gelu, approximate="none") + elif self == MoEActivationFn.gelu_tanh: + return partial(F.gelu, approximate="tanh") + elif self == MoEActivationFn.relu: + return partial(F.relu, inplace=False) + else: + raise NotImplementedError(self) + + +class MoEMLPImplementation(StrEnum): + """ + An enumeration of the different MoE implementations. + """ + + sparse = "sparse" + """ + Sparse implementation. + """ + grouped = "grouped" + """ + Requires the `grouped GEMM + `_ package. + """ + + +@beta_feature +@dataclass +class MoEConfig(Config): + """ + Configuration class for building MoE layers. + + .. important:: + Requires `megablocks `_. + """ + + name: MoEType = MoEType.default + """ + The MoE implementation. + """ + hidden_size: int = 4096 + """ + The MLP hidden size. + """ + activation_fn: MoEActivationFn = MoEActivationFn.swiglu + """ + The activation function to use. + """ + mlp_implementation: MoEMLPImplementation = MoEMLPImplementation.sparse + """ + The MLP implementation. + """ + memory_optimized_mlp: bool = False + """ + Use the memory-optimized version of the MLP. + """ + num_experts: int = 8 + """ + The number of experts to use in the MoE block. + """ + top_k: int = 2 + """ + The number of experts to select for each token. + """ + capacity_factor: int = 1 + """ + The capacity factor to use in the MoE block. Only applies if not using :data:`MoEType.dropless`. + """ + bias: bool = True + """ + Include bias terms. + """ + loss_weight: float = 0.1 + """ + The weight to use for the MoE load balancing loss. + """ + zloss_weight: float = 0.0 + """ + Weight for MoE router z-loss where None means no router z-loss. 0.001 is a common value. + """ + zloss_in_fp32: bool = False + """ + Whether to compute the z-loss in FP32. + """ + shared_expert: bool = False + """ + Whether to have an always-used expert like in `DeepSeekMoE + `_. + """ + lbl_in_fp32: bool = False + """ + Whether to perform load balancing in FP32. + """ + num_layers: int = 1 + """ + The total number of MoE layers. + """ + dtype: DType = DType.float32 + """ + The data type for the parameters. + """ + + def num_params(self, d_model: int) -> int: + num_params = 0 + + # Router. + num_params += self.num_experts * d_model + + # Experts. + num_params += self.num_experts * (2 * d_model * self.hidden_size) + if self.name == MoEType.dropless and "glu" in self.activation_fn.lower(): + num_params += self.num_experts * d_model * self.hidden_size + + # Bias. 
+ if self.bias: + num_params += d_model + + return num_params + + def as_megablocks_args(self, *, d_model: int, init_device: str = "cpu"): + from megablocks.layers.arguments import Arguments # type: ignore + + return Arguments( + hidden_size=d_model, + activation_fn=self.activation_fn.build(), + mlp_type="glu" if "glu" in self.activation_fn.lower() else "mlp", + mlp_impl=self.mlp_implementation, + memory_optimized_mlp=self.memory_optimized_mlp, + ffn_hidden_size=self.hidden_size, + moe_num_experts=self.num_experts, + moe_top_k=self.top_k, + moe_capacity_factor=self.capacity_factor, + moe_loss_weight=self.loss_weight, + moe_zloss_weight=self.zloss_weight, + moe_zloss_in_fp32=self.zloss_in_fp32, + moe_lbl_in_fp32=self.lbl_in_fp32, + shared_expert=self.shared_expert, + bias=self.bias, + return_bias=False, + num_layers=self.num_layers, + device=torch.device(init_device), + fp16=False, + bf16=self.dtype == DType.bfloat16, + ) + + def build(self, *, d_model: int, init_device: str = "cpu") -> MoEWrapper: + """ + Build the MoE layer. + + :param d_model: The model dimensionality. + :param init_device: The device to initialize weights on. + """ + try: + from megablocks.layers.dmoe import dMoE + from megablocks.layers.moe import MoE + except ImportError as e: + raise ImportError( + "megablocks is not installed. Please install it to use MoE layers" + ) from e + + args = self.as_megablocks_args(d_model=d_model, init_device=init_device) + if self.name == MoEType.default: + return MoEWrapper(args, MoE(args)) + elif self.name == MoEType.dropless: + return MoEWrapper(args, dMoE(args)) + else: + raise NotImplementedError(self.name) diff --git a/src/olmo_core/nn/moe/layers.py b/src/olmo_core/nn/moe/layers.py new file mode 100644 index 00000000..42efb7ce --- /dev/null +++ b/src/olmo_core/nn/moe/layers.py @@ -0,0 +1,101 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from olmo_core.doc_utils import beta_feature + + +@beta_feature +class MoE(nn.Module): + """ + A thin wrapper around `megablocks `_ MoE layers. + + .. tip:: + Use :class:`MoEConfig` to build instances of this module. + + .. important:: + This should always be used in conjunction with the + :class:`~olmo_core.train.callbacks.MoEHandlerCallback` for training. + """ + + def __init__(self, args, inner): + super().__init__() + self.args = args + self.inner = inner + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Run the MoE on the input. + + :param x: A tensor of shape ``(batch_size, sequence_length, d_model)``. + """ + return self.inner(x) + + def get_load_balancing_loss(self) -> Optional[torch.Tensor]: + """ + Get the batched load-balancing loss from the internal buffers. + + .. important:: + This method will clear the internal buffers so can only be called once per forward pass. + """ + from megablocks.layers.moe import ( # type: ignore + batched_load_balancing_loss, + clear_load_balancing_loss, + ) + + if isinstance(lb_loss := batched_load_balancing_loss(self.args), torch.Tensor): + clear_load_balancing_loss() + return lb_loss + else: + return None + + def get_router_z_loss(self) -> Optional[torch.Tensor]: + """ + Get the batched router Z-loss from the internal buffers. + + .. important:: + This method will clear the internal buffers so can only be called once per forward pass. 
+ """ + from megablocks.layers.router import ( # type: ignore + batched_router_zloss, + clear_router_zloss, + ) + + if self.args.moe_zloss_weight != 0 and isinstance( + (z_loss_per_layer := batched_router_zloss(self.args)), torch.Tensor + ): + z_loss = z_loss_per_layer.sum() / self.args.num_layers + clear_router_zloss() + return z_loss + else: + return None + + def get_loss(self) -> Optional[torch.Tensor]: + """ + Get the batched combined load-balancing loss and router Z-loss from the internal buffers. + + .. important:: + This method will clear the internal buffers so can only be called once per forward pass. + """ + loss: Optional[torch.Tensor] = None + if (lb_loss := self.get_load_balancing_loss()) is not None: + loss = lb_loss + + if (rz_loss := self.get_router_z_loss()) is not None: + if loss is not None: + loss += rz_loss + else: + loss = rz_loss + + return loss + + def clear_losses(self): + """ + Clear internal loss buffers. + """ + from megablocks.layers.moe import clear_load_balancing_loss # type: ignore + from megablocks.layers.router import clear_router_zloss # type: ignore + + clear_load_balancing_loss() + clear_router_zloss() diff --git a/src/olmo_core/nn/transformer/__init__.py b/src/olmo_core/nn/transformer/__init__.py index 5e0ffc97..690769d9 100644 --- a/src/olmo_core/nn/transformer/__init__.py +++ b/src/olmo_core/nn/transformer/__init__.py @@ -3,8 +3,11 @@ """ from .block import ( + MoEReorderedNormTransformerBlock, + MoETransformerBlock, ReorderedNormTransformerBlock, TransformerBlock, + TransformerBlockBase, TransformerBlockConfig, TransformerBlockType, ) @@ -12,7 +15,10 @@ from .model import ( Transformer, TransformerActivationCheckpointingConfig, + TransformerActivationCheckpointingMode, TransformerConfig, + TransformerDataParallelConfig, + TransformerDataParallelWrappingStrategy, ) __all__ = [ @@ -20,8 +26,14 @@ "Transformer", "TransformerBlockType", "TransformerBlockConfig", + "TransformerBlockBase", "TransformerBlock", "ReorderedNormTransformerBlock", + "MoETransformerBlock", + "MoEReorderedNormTransformerBlock", + "TransformerDataParallelConfig", + "TransformerDataParallelWrappingStrategy", "TransformerActivationCheckpointingConfig", + "TransformerActivationCheckpointingMode", "InitMethod", ] diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index f4c84618..810273bc 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from dataclasses import dataclass from typing import Optional @@ -5,11 +6,13 @@ import torch.nn as nn from olmo_core.config import Config, StrEnum +from olmo_core.exceptions import OLMoConfigurationError from ..attention import AttentionConfig from ..buffer_cache import BufferCache from ..feed_forward import FeedForwardConfig from ..layer_norm import LayerNormConfig +from ..moe import MoEConfig class TransformerBlockType(StrEnum): @@ -27,6 +30,16 @@ class TransformerBlockType(StrEnum): :class:`ReorderedNormTransformerBlock` """ + moe = "moe" + """ + :class:`MoETransformerBlock` + """ + + moe_reordered_norm = "moe" + """ + :class:`MoEReorderedNormTransformerBlock` + """ + @dataclass class TransformerBlockConfig(Config): @@ -35,10 +48,29 @@ class TransformerBlockConfig(Config): """ attention: AttentionConfig - feed_forward: FeedForwardConfig + """ + The attention config. + """ layer_norm: LayerNormConfig + """ + The layer norm config. 
+ """ + feed_forward: Optional[FeedForwardConfig] = None + """ + The feed-forward config, required for non-MoE blocks. + """ + feed_forward_moe: Optional[MoEConfig] = None + """ + The config for the MoE feed-forward layer. Required for MoE blocks. + """ name: TransformerBlockType = TransformerBlockType.default + """ + The block type. + """ dropout: float = 0.0 + """ + Dropout probability. + """ def build( self, @@ -47,27 +79,90 @@ def build( block_idx: int, init_device: str = "cpu", cache: Optional[BufferCache] = None, - ) -> "TransformerBlock": - kwargs = self.as_dict(exclude_none=True, recurse=False) - kwargs.pop("name") - kwargs.update( - dict( + ) -> "TransformerBlockBase": + if self.name == TransformerBlockType.default: + if self.feed_forward is None: + raise OLMoConfigurationError("'feed_forward' config is required") + return TransformerBlock( d_model=d_model, block_idx=block_idx, + attention=self.attention, + feed_forward=self.feed_forward, + layer_norm=self.layer_norm, + dropout=self.dropout, init_device=init_device, cache=cache, ) - ) - - if self.name == TransformerBlockType.default: - return TransformerBlock(**kwargs) elif self.name == TransformerBlockType.reordered_norm: - return ReorderedNormTransformerBlock(**kwargs) + if self.feed_forward is None: + raise OLMoConfigurationError("'feed_forward' config is required") + return ReorderedNormTransformerBlock( + d_model=d_model, + block_idx=block_idx, + attention=self.attention, + feed_forward=self.feed_forward, + layer_norm=self.layer_norm, + dropout=self.dropout, + init_device=init_device, + cache=cache, + ) + elif self.name == TransformerBlockType.moe: + if self.feed_forward_moe is None: + raise OLMoConfigurationError("'feed_forward_moe' config is required for MoE blocks") + return MoETransformerBlock( + d_model=d_model, + block_idx=block_idx, + attention=self.attention, + feed_forward_moe=self.feed_forward_moe, + layer_norm=self.layer_norm, + dropout=self.dropout, + init_device=init_device, + cache=cache, + ) + elif self.name == TransformerBlockType.moe_reordered_norm: + if self.feed_forward_moe is None: + raise OLMoConfigurationError("'feed_forward_moe' config is required for MoE blocks") + return MoEReorderedNormTransformerBlock( + d_model=d_model, + block_idx=block_idx, + attention=self.attention, + feed_forward_moe=self.feed_forward_moe, + layer_norm=self.layer_norm, + dropout=self.dropout, + init_device=init_device, + cache=cache, + ) else: raise NotImplementedError(self.name) -class TransformerBlock(nn.Module): +class TransformerBlockBase(nn.Module): + """ + Base class for transformer block implementations. + """ + + @abstractmethod + def forward( + self, + x: torch.Tensor, + max_doc_len: Optional[int] = None, + cu_doc_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Run the block on the input ``x``. + + :param x: The input of shape ``(batch_size, seq_len, d_model)``. + :param max_doc_len: The maximum document length in the input ``x``. + Required together with ``cu_doc_lens`` when using intra-document masking. + :param cu_doc_lens: Cumulative document lengths in the input ``x``, a 1D + :class:`torch.int32` tensor that should always have one more element than there + are documents (the first element in the tensor should always be ``0``). + Required together with ``max_doc_len`` when using intra-document masking. + """ + raise NotImplementedError + + +class TransformerBlock(TransformerBlockBase): """ A typical "Llama-style" transformer block implementation. 
@@ -107,17 +202,6 @@ def forward( max_doc_len: Optional[int] = None, cu_doc_lens: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """ - Run the block on the input ``x``. - - :param x: The input of shape ``(batch_size, seq_len, d_model)``. - :param max_doc_len: The maximum document length in the input ``x``. - Required together with ``cu_doc_lens`` when using intra-document masking. - :param cu_doc_lens: Cumulative document lengths in the input ``x``, a 1D - :class:`torch.int32` tensor that should always have one more element than there - are documents (the first element in the tensor should always be ``0``). - Required together with ``max_doc_len`` when using intra-document masking. - """ h = x + self.dropout( self.attention(self.attention_norm(x), max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens) ) @@ -141,3 +225,66 @@ def forward( self.attention_norm(self.attention(x, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens)) ) return h + self.dropout(self.feed_forward_norm(self.feed_forward(h))) + + +class MoETransformerBlock(TransformerBlockBase): + """ + Like :class:`TransformerBlock` except that the dense :class:`~olmo_core.nn.feed_forward.FeedForward` + module is replaced with a mixture-of-experts (MoE). + """ + + def __init__( + self, + *, + d_model: int, + block_idx: int, + attention: AttentionConfig, + feed_forward_moe: MoEConfig, + layer_norm: LayerNormConfig, + dropout: float = 0.0, + init_device: str = "cpu", + cache: Optional[BufferCache] = None, + ): + super().__init__() + self.d_model = d_model + self.block_idx = block_idx + self.attention = attention.build(d_model, init_device=init_device, cache=cache) + self.attention_norm = layer_norm.build(d_model, init_device=init_device) + self.feed_forward_moe = feed_forward_moe.build(d_model=d_model, init_device=init_device) + self.feed_forward_norm = layer_norm.build(d_model, init_device=init_device) + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward( + self, + x: torch.Tensor, + max_doc_len: Optional[int] = None, + cu_doc_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Run the block on the input ``x``. + + Parameters are the same as :meth:`TransformerBlock.forward()`. + """ + h = x + self.dropout( + self.attention(self.attention_norm(x), max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens) + ) + return h + self.dropout(self.feed_forward_moe(self.feed_forward_norm(h))) + + +class MoEReorderedNormTransformerBlock(MoETransformerBlock): + """ + Like :class:`MoETransformerBlock` except that the attention norm is applied on the output + of attention instead of the input, and likewise the feed-forward norm is applied on the + output of the feed-forward MoE instead of the input. 
+ """ + + def forward( + self, + x: torch.Tensor, + max_doc_len: Optional[int] = None, + cu_doc_lens: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + h = x + self.dropout( + self.attention_norm(self.attention(x, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens)) + ) + return h + self.dropout(self.feed_forward_norm(self.feed_forward_moe(h))) diff --git a/src/olmo_core/nn/transformer/init.py b/src/olmo_core/nn/transformer/init.py index 32d8e94f..c022cb3f 100644 --- a/src/olmo_core/nn/transformer/init.py +++ b/src/olmo_core/nn/transformer/init.py @@ -7,6 +7,7 @@ from ..attention import Attention, FusedAttention from ..feed_forward import FeedForward +from ..moe import MoE class InitMethod(StrEnum): @@ -94,3 +95,36 @@ def init_feed_forward( self._init_linear(m.w1, std=0.02, generator=generator) self._init_linear(m.w2, std=std, generator=generator) self._init_linear(m.w3, std=std, generator=generator) + + def init_feed_forward_moe( + self, + m: MoE, + *, + block_idx: int, + num_blocks: int, + generator: Optional[torch.Generator] = None, + ): + std = 0.02 + if self == InitMethod.llama: + std = 0.02 / (2 * num_blocks) ** 0.5 + elif self == InitMethod.llama_depth: + std = 0.02 / (2 * (block_idx + 1)) ** 0.5 + + self._init_linear(m.inner.router.layer, std=0.02, generator=generator) + nn.init.trunc_normal_( + m.inner.experts.mlp.w1, mean=0.0, std=0.02, a=-3 * std, b=3 * std, generator=generator + ) + nn.init.trunc_normal_( + m.inner.experts.mlp.w2, mean=0.0, std=std, a=-3 * std, b=3 * std, generator=generator + ) + if hasattr(m.inner.experts.mlp, "v1"): + nn.init.trunc_normal_( + m.inner.experts.mlp.v1, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + generator=generator, + ) + if (bias := getattr(m.inner.experts, "bias", None)) is not None: + nn.init.zeros_(bias) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index abb179bb..4bb3076e 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -13,6 +13,7 @@ from olmo_core.config import Config, DType, StrEnum from olmo_core.data.utils import get_cumulative_document_lengths from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType +from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError from olmo_core.float8 import Float8Config from olmo_core.utils import get_default_device, has_flash_attn @@ -28,6 +29,8 @@ __all__ = [ "TransformerConfig", "Transformer", + "TransformerDataParallelConfig", + "TransformerDataParallelWrappingStrategy", "TransformerActivationCheckpointingConfig", "TransformerActivationCheckpointingMode", ] @@ -36,7 +39,37 @@ log = logging.getLogger(__name__) +class TransformerDataParallelWrappingStrategy(StrEnum): + """ + An enumeration of the different wrapping strategy for the data parallel implementations. + """ + + full = "full" + """ + Wrap each block (only applies to FSDP). + """ + fine_grained = "fine_grained" + """ + Wrap certain modules within each block in addition to wrapping each block (only applies to FSDP). + """ + + +@dataclass +class TransformerDataParallelConfig(DataParallelConfig): + wrapping_strategy: TransformerDataParallelWrappingStrategy = ( + TransformerDataParallelWrappingStrategy.full + ) + """ + Wrapping strategy. + """ + + +@beta_feature class TransformerActivationCheckpointingMode(StrEnum): + """ + An enumeration of the different activation checkpointing modes. 
+ """ + full = "full" """Checkpoint every block.""" selected_blocks = "selected_blocks" @@ -45,6 +78,7 @@ class TransformerActivationCheckpointingMode(StrEnum): """Checkpoint only selected modules.""" +@beta_feature @dataclass class TransformerActivationCheckpointingConfig(Config): """ @@ -104,7 +138,7 @@ class TransformerConfig(Config): init_method: InitMethod = InitMethod.normal init_seed: int = 0 compile: bool = False - dp_config: Optional[DataParallelConfig] = None + dp_config: Optional[TransformerDataParallelConfig] = None ac_config: Optional[TransformerActivationCheckpointingConfig] = None float8_config: Optional[Float8Config] = None @@ -176,6 +210,7 @@ def build( if self.dp_config.param_dtype is not None else None, reduce_dtype=self.dp_config.reduce_dtype.as_pt(), + wrapping_strategy=self.dp_config.wrapping_strategy, ) elif self.dp_config.name == DataParallelType.ddp: model.apply_ddp(dp_mesh=dp_mesh, compile_enabled=self.compile) @@ -237,9 +272,14 @@ def layer_norm_params(layer_norm: LayerNormConfig) -> int: block_params += layer_norm_params(self.block.layer_norm) # Block feed forward. - block_params += 3 * self.d_model * self.block.feed_forward.hidden_size - if self.block.feed_forward.bias: - block_params += 2 * self.block.feed_forward.hidden_size + self.d_model + if "moe" not in self.block.name: + assert self.block.feed_forward is not None + block_params += 3 * self.d_model * self.block.feed_forward.hidden_size + if self.block.feed_forward.bias: + block_params += 2 * self.block.feed_forward.hidden_size + self.d_model + else: + assert self.block.feed_forward_moe is not None + block_params += self.block.feed_forward_moe.num_params(self.d_model) # Block feed forward norm. block_params += layer_norm_params(self.block.layer_norm) @@ -334,8 +374,8 @@ def llama2_271M(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=1024, vocab_size=vocab_size, - n_layers=16, - n_heads=8, + n_layers=kwargs.pop("n_layers", 16), + n_heads=kwargs.pop("n_heads", 8), rope_theta=kwargs.pop("rope_theta", 10_000), **kwargs, ) @@ -348,8 +388,8 @@ def llama2_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=2048, vocab_size=vocab_size, - n_layers=18, - n_heads=16, + n_layers=kwargs.pop("n_layers", 18), + n_heads=kwargs.pop("n_heads", 16), rope_theta=kwargs.pop("rope_theta", 10_000), **kwargs, ) @@ -362,8 +402,8 @@ def llama2_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=4096, vocab_size=vocab_size, - n_layers=32, - n_heads=32, + n_layers=kwargs.pop("n_layers", 32), + n_heads=kwargs.pop("n_heads", 32), rope_theta=kwargs.pop("rope_theta", 10_000), **kwargs, ) @@ -376,8 +416,8 @@ def llama2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=5120, vocab_size=vocab_size, - n_layers=40, - n_heads=40, + n_layers=kwargs.pop("n_layers", 40), + n_heads=kwargs.pop("n_heads", 40), rope_theta=kwargs.pop("rope_theta", 10_000), **kwargs, ) @@ -390,8 +430,8 @@ def llama2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=5120, vocab_size=vocab_size, - n_layers=80, - n_heads=40, + n_layers=kwargs.pop("n_layers", 80), + n_heads=kwargs.pop("n_heads", 40), rope_theta=kwargs.pop("rope_theta", 10_000), **kwargs, ) @@ -404,9 +444,9 @@ def llama2_70B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=8192, vocab_size=vocab_size, - n_layers=80, - n_heads=64, - n_kv_heads=8, + 
n_layers=kwargs.pop("n_layers", 80), + n_heads=kwargs.pop("n_heads", 64), + n_kv_heads=kwargs.pop("n_kv_heads", 8), rope_theta=kwargs.pop("rope_theta", 10_000), hidden_size_multiplier=1.3, hidden_size_multiple_of=4096, @@ -421,9 +461,9 @@ def llama3_8B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=4096, vocab_size=vocab_size, - n_layers=32, - n_heads=32, - n_kv_heads=8, + n_layers=kwargs.pop("n_layers", 32), + n_heads=kwargs.pop("n_heads", 32), + n_kv_heads=kwargs.pop("n_kv_heads", 8), rope_theta=kwargs.pop("rope_theta", 500_000), hidden_size_multiplier=1.3, hidden_size_multiple_of=1024, @@ -438,9 +478,9 @@ def llama3_70B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": return cls.llama_like( d_model=8196, vocab_size=vocab_size, - n_layers=80, - n_heads=64, - n_kv_heads=8, + n_layers=kwargs.pop("n_layers", 80), + n_heads=kwargs.pop("n_heads", 64), + n_kv_heads=kwargs.pop("n_kv_heads", 8), rope_theta=kwargs.pop("rope_theta", 500_000), hidden_size_multiplier=1.3, hidden_size_multiple_of=4096, @@ -459,9 +499,9 @@ def llama3_405B( return cls.llama_like( d_model=16384, vocab_size=vocab_size, - n_layers=126, - n_heads=128, - n_kv_heads=8, + n_layers=kwargs.pop("n_layers", 126), + n_heads=kwargs.pop("n_heads", 128), + n_kv_heads=kwargs.pop("n_kv_heads", 8), rope_theta=kwargs.pop("rope_theta", 500_000), hidden_size_multiplier=1.2, hidden_size_multiple_of=4096, @@ -658,12 +698,20 @@ def init_weights( ) # Feed-forward weights. - self.init_method.init_feed_forward( - block.feed_forward, - block_idx=block.block_idx, - num_blocks=len(self.blocks), - generator=generator, - ) + if hasattr(block, "feed_forward"): + self.init_method.init_feed_forward( + block.feed_forward, + block_idx=block.block_idx, + num_blocks=len(self.blocks), + generator=generator, + ) + else: + self.init_method.init_feed_forward_moe( + block.feed_forward_moe, + block_idx=block.block_idx, + num_blocks=len(self.blocks), + generator=generator, + ) # Warm up RoPE cache. if max_seq_len is not None and att.rope is not None: @@ -793,6 +841,7 @@ def apply_fsdp( param_dtype: Optional[torch.dtype] = None, reduce_dtype: torch.dtype = torch.float32, pp_enabled: bool = False, + wrapping_strategy: TransformerDataParallelWrappingStrategy = TransformerDataParallelWrappingStrategy.full, ): """ Apply FSDP(2) to the model. @@ -805,6 +854,7 @@ def apply_fsdp( :param param_dtype: The data type to materialize params in. Defaults to the current param dtype. :param reduce_dtype: The data type for gradient reduction. :pp_enabled: If pipeline parallelism is also enabled. + :wrapping_strategy: The wrapping strategy. 
""" # Adapted from # https://github.com/pytorch/torchtitan/blob/90c889e972b56b9faadebbb78fc985dedc537ed9/torchtitan/parallelisms/parallelize_llama.py#L289 @@ -817,16 +867,36 @@ def apply_fsdp( fsdp_config = dict(mesh=dp_mesh, mp_policy=mp_policy) for block_id, block in enumerate(self.blocks): + reshard_after_forward = True if pp_enabled: # For PP, do not reshard after forward to avoid per-microbatch # all-gathers, which can be expensive and non-overlapped reshard_after_forward = False - else: + elif wrapping_strategy == TransformerDataParallelWrappingStrategy.full: # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately reshard_after_forward = int(block_id) < len(self.blocks) - 1 + + if wrapping_strategy == TransformerDataParallelWrappingStrategy.fine_grained: + if hasattr(block, "feed_forward"): + fully_shard( + block.feed_forward, + reshard_after_forward=reshard_after_forward, + **fsdp_config, + ) + else: + fully_shard( + block.feed_forward_moe, + reshard_after_forward=reshard_after_forward, + **fsdp_config, + ) + fully_shard(block, reshard_after_forward=reshard_after_forward, **fsdp_config) + if wrapping_strategy == TransformerDataParallelWrappingStrategy.fine_grained: + fully_shard(self.embeddings, reshard_after_forward=not pp_enabled, **fsdp_config) + fully_shard(self.w_out, reshard_after_forward=False, **fsdp_config) + fully_shard(self, reshard_after_forward=not pp_enabled, **fsdp_config) if dp_mesh is None: diff --git a/src/olmo_core/train/callbacks/__init__.py b/src/olmo_core/train/callbacks/__init__.py index 791b9948..44028505 100644 --- a/src/olmo_core/train/callbacks/__init__.py +++ b/src/olmo_core/train/callbacks/__init__.py @@ -8,6 +8,7 @@ from .garbage_collector import GarbageCollectorCallback from .gpu_memory_monitor import GPUMemoryMonitorCallback from .grad_clipper import GradClipperCallback +from .moe_handler import MoEHandlerCallback from .profiler import ProfilerCallback from .scheduler import SchedulerCallback from .sequence_length_scheduler import SequenceLengthSchedulerCallback @@ -26,6 +27,7 @@ "EvaluatorCallback", "Float8HandlerCallback", "LMEvaluatorCallbackConfig", + "MoEHandlerCallback", "GarbageCollectorCallback", "GPUMemoryMonitorCallback", "GradClipperCallback", diff --git a/src/olmo_core/train/callbacks/callback.py b/src/olmo_core/train/callbacks/callback.py index 546c8bd3..2c00c60f 100644 --- a/src/olmo_core/train/callbacks/callback.py +++ b/src/olmo_core/train/callbacks/callback.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Dict +import torch + from olmo_core.aliases import PathOrStr from olmo_core.config import Config @@ -86,6 +88,23 @@ def pre_step(self, batch: Dict[str, Any]): """ del batch + def pre_backward( + self, + *, + batch: Dict[str, Any], + micro_batch: Dict[str, Any], + loss: torch.Tensor, + ): + """ + Runs right before the backward pass on a micro-batch. This can be used to modify the + ``loss`` before ``loss.backward()`` is called. + + :param batch: The full batch. + :param micro_batch: The micro-batch just used. + :param loss: The combined loss from the micro-batch (``ce_loss`` plus the optional ``z_loss``). + """ + del batch, micro_batch, loss + def pre_optim_step(self): """ Runs right after the forward-backward passes, right before the optimizer step. 
@@ -98,6 +117,20 @@ def post_train_batch(self):
         """
         pass
 
+    def pre_eval_batch(self, batch: Dict[str, Any]):
+        """
+        Runs right before an eval batch is processed with :meth:`~olmo_core.train.Trainer.eval_batch()`.
+
+        :param batch: The eval batch.
+        """
+        del batch
+
+    def post_eval_batch(self):
+        """
+        Runs after an eval batch is processed with :meth:`~olmo_core.train.Trainer.eval_batch()`.
+        """
+        pass
+
     def post_step(self):
         """
         Runs after a complete step (potentially including evals and checkpointing).
diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py
index 08148872..6f486f87 100644
--- a/src/olmo_core/train/callbacks/evaluator_callback.py
+++ b/src/olmo_core/train/callbacks/evaluator_callback.py
@@ -2,8 +2,6 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, List, Optional
 
-import torch
-
 from olmo_core.data import NumpyDatasetConfig, NumpyPaddedFSLDataset
 from olmo_core.distributed.utils import get_world_size
 from olmo_core.eval import Evaluator
@@ -64,10 +62,9 @@ def post_step(self):
                 eval_step += 1
                 eval_tokens += batch["input_ids"].numel() * dp_world_size
                 batch = move_to_device(batch, self.trainer.device)
-                with torch.no_grad():
-                    ce_loss, _, logits = self.trainer._model_forward(
-                        batch, loss_reduction="none", compute_z_loss=False
-                    )
+                logits, ce_loss, _ = self.trainer.eval_batch(
+                    batch, loss_reduction="none", compute_z_loss=False
+                )
                 evaluator.update_metrics(batch, ce_loss, logits)
 
                 if eval_step % self.trainer.cancel_check_interval == 0:
diff --git a/src/olmo_core/train/callbacks/moe_handler.py b/src/olmo_core/train/callbacks/moe_handler.py
new file mode 100644
index 00000000..de6c465d
--- /dev/null
+++ b/src/olmo_core/train/callbacks/moe_handler.py
@@ -0,0 +1,86 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+
+from olmo_core.distributed.utils import get_local_tensor
+from olmo_core.exceptions import OLMoConfigurationError
+from olmo_core.nn.moe import MoE
+from olmo_core.utils import move_to_device
+
+from .callback import Callback
+
+
+@dataclass
+class MoEHandlerCallback(Callback):
+    """
+    A callback to be used in conjunction with :class:`~olmo_core.nn.moe.MoE` based models for
+    including the MoE's internal losses in the training loss.
+ """ + + _batch_lb_loss = None + _batch_z_loss = None + _moe_layer = None + + def clear_loss_buffers(self): + assert self._moe_layer is not None + self._moe_layer.clear_losses() + if self._batch_lb_loss is not None: + self._batch_lb_loss.zero_() + if self._batch_z_loss is not None: + self._batch_z_loss.zero_() + + def pre_train(self): + for module in self.trainer.model.modules(): + if isinstance(module, MoE): + self._moe_layer = module # only need one + break + else: + raise OLMoConfigurationError( + f"No MoE layer found in model, required by {self.__class__.__name__}" + ) + + def pre_step(self, batch: Dict[str, Any]): + del batch + self.clear_loss_buffers() + + def post_eval_batch(self): + self.clear_loss_buffers() + + def pre_backward( + self, + *, + batch: Dict[str, Any], + micro_batch: Dict[str, Any], + loss: torch.Tensor, + ): + assert self._moe_layer is not None + + scale_factor = micro_batch["input_ids"].shape[0] / batch["input_ids"].shape[0] + + moe_loss: Optional[torch.Tensor] = None + if (lb_loss := self._moe_layer.get_load_balancing_loss()) is not None: + lb_loss.mul_(scale_factor) + moe_loss = lb_loss + if self._batch_lb_loss is None: + self._batch_lb_loss = move_to_device(torch.tensor(0.0), lb_loss.device) + self._batch_lb_loss += get_local_tensor(lb_loss) + + if (rz_loss := self._moe_layer.get_router_z_loss()) is not None: + rz_loss.mul_(scale_factor) + if moe_loss is not None: + moe_loss += rz_loss + else: + moe_loss = rz_loss + if self._batch_z_loss is None: + self._batch_z_loss = move_to_device(torch.tensor(0.0), rz_loss.device) + self._batch_z_loss += get_local_tensor(rz_loss) + + if moe_loss is not None: + loss += moe_loss + + def post_train_batch(self): + if self._batch_lb_loss is not None: + self.trainer.record_metric("train/load balancing loss", self._batch_lb_loss) + if self._batch_z_loss is not None: + self.trainer.record_metric("train/router Z loss", self._batch_z_loss) diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index ff4392b0..7b960c5c 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -836,6 +836,98 @@ def add_callback(self, name: str, callback: Callback): self._sort_callbacks() callback.post_attach() + def model_forward(self, micro_batch: Dict[str, Any]) -> torch.Tensor: + """ + Run a forward pass on a micro-batch, returning the logits. + """ + with self._model_forward_context(): + # shape: (batch_size, seq_len, vocab_size) + logits = self.model( + input_ids=micro_batch["input_ids"], + # attention_mask=micro_batch.get("attention_mask"), + # attention_bias=micro_batch.get("attention_bias"), + doc_lens=micro_batch.get("doc_lens"), + max_doc_lens=micro_batch.get("max_doc_lens"), + ) + return logits + + def get_losses( + self, + micro_batch: Dict[str, Any], + logits: torch.Tensor, + loss_reduction: Literal["mean", "sum", "none"] = "mean", + compute_z_loss: Optional[bool] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute the cross-entropy loss and optionally the Z-loss from a micro-batch and the + corresponding logits returned from :meth:`model_forward()`. + + :param micro_batch: The micro-batch to evaluate. + :param logits: The logits from the forward pass. + :param loss_reduction: The (local) reduction to apply to the loss(es). + :param compute_z_loss: Whether or not to compute and return the Z-loss. + + :returns: The cross entropy and optional Z-loss, respectively. 
+ """ + loss_fn = cross_entropy_loss if not self.fused_loss else fused_cross_entropy_loss + if compute_z_loss is None: + compute_z_loss = self.z_loss_multiplier is not None + + # shape: (batch_size, seq_len - 1, vocab_size) + logits_for_loss = logits[..., :-1, :].contiguous() + # shape: (batch_size * (seq_len - 1), vocab_size) + logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1)) + + # shape: (batch_size, seq_len - 1) + labels = micro_batch.get("labels", self._get_labels(micro_batch)) + # shape: (batch_size * (seq_len - 1),) + labels = labels.view(-1) + + ce_loss, z_loss = loss_fn( + logits_for_loss, + labels, + ignore_index=self.data_loader.collator.label_ignore_index, + reduction=loss_reduction, + compute_z_loss=compute_z_loss, + z_loss_multiplier=self.z_loss_multiplier or 1e-4, + ) + + if loss_reduction == "none": + # Reshape (batch_size * (seq_len - 1),) -> (batch_size, seq_len - 1) + ce_loss = ce_loss.view(micro_batch["input_ids"].shape[0], -1) + if z_loss is not None: + z_loss = z_loss.view(micro_batch["input_ids"].shape[0], -1) + + return ce_loss, z_loss + + def eval_batch( + self, + batch: Dict[str, Any], + loss_reduction: Literal["mean", "sum", "none"] = "mean", + compute_z_loss: Optional[bool] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Get the loss for an eval batch. + + .. important:: + You are responsible for ensuring the model is in ``.eval()`` mode before calling this. + + :param batch: The batch to evaluate. + :param loss_reduction: The (local) reduction to apply to the loss(es). + :param compute_z_loss: Whether or not to compute and return the Z-loss. + + :returns: The logits, cross-entropy loss, and Z-loss, respectively. + """ + batch = move_to_device(batch, self.device) + for callback in self.callbacks.values(): + callback.pre_eval_batch(batch) + with torch.no_grad(): + logits = self.model_forward(batch) + ce_loss, z_loss = self.get_losses( + batch, logits, loss_reduction=loss_reduction, compute_z_loss=compute_z_loss + ) + return logits, ce_loss, z_loss + def _sort_callbacks(self): self.callbacks = OrderedDict( ( @@ -955,69 +1047,13 @@ def _model_forward( loss_reduction: Literal["mean", "sum", "none"] = "mean", compute_z_loss: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - with self._model_forward_context(): - # shape: (batch_size, seq_len, vocab_size) - logits = self.model( - input_ids=batch["input_ids"], - # attention_mask=batch.get("attention_mask"), - # attention_bias=batch.get("attention_bias"), - doc_lens=batch.get("doc_lens"), - max_doc_lens=batch.get("max_doc_lens"), - ) - - # shape: (batch_size, seq_len - 1, vocab_size) - logits_for_loss = logits[..., :-1, :].contiguous() - # shape: (batch_size * (seq_len - 1), vocab_size) - logits_for_loss = logits_for_loss.view(-1, logits_for_loss.size(-1)) - # shape: (batch_size, seq_len - 1) - labels = batch.get("labels", self._get_labels(batch)) - # shape: (batch_size * (seq_len - 1),) - labels = labels.view(-1) - - loss_fn = cross_entropy_loss if not self.fused_loss else fused_cross_entropy_loss - ce_loss, z_loss = loss_fn( - logits_for_loss, - labels, - ignore_index=self.data_loader.collator.label_ignore_index, - reduction=loss_reduction, - compute_z_loss=compute_z_loss, - z_loss_multiplier=self.z_loss_multiplier or 1e-4, + # NOTE: keep this method for backwards compatibility. 
+ logits = self.model_forward(batch) + ce_loss, z_loss = self.get_losses( + batch, logits, loss_reduction=loss_reduction, compute_z_loss=compute_z_loss ) - - if loss_reduction == "none": - # Reshape (batch_size * (seq_len - 1),) -> (batch_size, seq_len - 1) - ce_loss = ce_loss.view(batch["input_ids"].shape[0], -1) - if z_loss is not None: - z_loss = z_loss.view(batch["input_ids"].shape[0], -1) - return ce_loss, z_loss, logits - def _get_microbatch_loss( - self, micro_batch: Dict[str, Any], batch_num_tokens_for_loss: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' - # (the total number of tokens used in the loss across the whole batch, not just the micro batch) - # to avoid biasing the loss in the case where micro-batches might not be the same size. - ce_loss, z_loss, logits = self._model_forward( - micro_batch, compute_z_loss=self.z_loss_multiplier is not None, loss_reduction="sum" - ) - ce_loss = ce_loss / batch_num_tokens_for_loss - - # In case this helps with memory utilization. - del micro_batch - - # Get loss to optimize for. - if self.z_loss_multiplier is not None: - assert z_loss is not None - z_loss = z_loss / batch_num_tokens_for_loss - loss = ce_loss + z_loss - else: - loss = ce_loss - - del logits - - return loss, ce_loss, z_loss - @contextlib.contextmanager def _train_microbatch_context( self, micro_batch_idx: int, num_micro_batches: int @@ -1054,9 +1090,6 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False): micro_batches = split_batch(batch, self.rank_microbatch_size // seq_len) num_micro_batches = len(micro_batches) - # In case this helps with memory utilization. - del batch - ce_batch_loss = move_to_device(torch.tensor(0.0), self.device) z_batch_loss = ( None @@ -1068,9 +1101,22 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False): for micro_batch_idx, micro_batch in enumerate(micro_batches): with self._train_microbatch_context(micro_batch_idx, num_micro_batches): # Run forward pass. - loss, ce_loss, z_loss = self._get_microbatch_loss( - micro_batch, batch_num_tokens_for_loss - ) + logits = self.model_forward(micro_batch) + + # NOTE: we use the "sum" loss reduction and then divide by 'batch_num_tokens_for_loss' + # (the total number of tokens used in the loss across the whole batch, not just the micro batch) + # to avoid biasing the loss in the case where micro-batches might not be the same size. + ce_loss, z_loss = self.get_losses(micro_batch, logits, loss_reduction="sum") + ce_loss.div_(batch_num_tokens_for_loss) + if z_loss is not None: + z_loss.div_(batch_num_tokens_for_loss) + + # Get loss to optimize for. + loss: torch.Tensor + if z_loss is not None: + loss = ce_loss + z_loss + else: + loss = ce_loss # Update overall CE batch loss. ce_batch_loss += get_local_tensor(ce_loss.detach()) @@ -1080,9 +1126,16 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False): assert z_batch_loss is not None z_batch_loss += get_local_tensor(z_loss.detach()) + # Run through callbacks. + for callback in self.callbacks.values(): + callback.pre_backward(batch=batch, micro_batch=micro_batch, loss=loss) + # Run backward pass. loss.backward() + # In case this helps with memory utilization. + del batch + if dry_run: # Zero-gradients again. 
self.optim.zero_grad(set_to_none=True) @@ -1104,6 +1157,10 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False): if isinstance(self.optim, SkipStepOptimizer): self.record_metric(OPTIM_STEP_SKIPPED_METRIC, self.optim.step_skipped) + # Run through callbacks. + for callback in self.callbacks.values(): + callback.post_train_batch() + def _iter_batches(self) -> Generator[Dict[str, Any], None, None]: data_iterator = iter(self.data_loader) @@ -1164,9 +1221,6 @@ def _fit_epoch(self): self._train_batch(batch) - for callback in self.callbacks.values(): - callback.post_train_batch() - for callback in self.callbacks.values(): callback.post_step() diff --git a/src/scripts/train/OLMo-13B.py b/src/scripts/train/OLMo-13B.py index 35551266..0098941a 100644 --- a/src/scripts/train/OLMo-13B.py +++ b/src/scripts/train/OLMo-13B.py @@ -5,9 +5,9 @@ import logging from olmo_core.config import DType -from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType +from olmo_core.distributed.parallel import DataParallelType from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.transformer import TransformerConfig +from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig from olmo_core.optim import AdamWConfig, OptimGroupOverride from olmo_core.train import TrainerConfig from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback @@ -19,7 +19,7 @@ def build_model_config(common: CommonComponents) -> TransformerConfig: return TransformerConfig.olmo_13B( vocab_size=common.tokenizer.padded_vocab_size(), compile=True, - dp_config=DataParallelConfig( + dp_config=TransformerDataParallelConfig( name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 ), ) diff --git a/src/scripts/train/OLMo-1B.py b/src/scripts/train/OLMo-1B.py index 9c768ce9..28a92c10 100644 --- a/src/scripts/train/OLMo-1B.py +++ b/src/scripts/train/OLMo-1B.py @@ -2,24 +2,20 @@ Train a 1B OLMo model. Run this script without any arguments to see usage info. 
""" -import logging - from olmo_core.config import DType -from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType +from olmo_core.distributed.parallel import DataParallelType from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.transformer import TransformerConfig +from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig from olmo_core.optim import AdamWConfig, OptimGroupOverride from olmo_core.train import TrainerConfig from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback -log = logging.getLogger(__name__) - def build_model_config(common: CommonComponents) -> TransformerConfig: return TransformerConfig.olmo_1B( vocab_size=common.tokenizer.padded_vocab_size(), compile=True, - dp_config=DataParallelConfig( + dp_config=TransformerDataParallelConfig( name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 ), ) diff --git a/src/scripts/train/OLMo-7B.py b/src/scripts/train/OLMo-7B.py index 2533f5a2..6ecd91a6 100644 --- a/src/scripts/train/OLMo-7B.py +++ b/src/scripts/train/OLMo-7B.py @@ -5,9 +5,9 @@ import logging from olmo_core.config import DType -from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType +from olmo_core.distributed.parallel import DataParallelType from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.transformer import TransformerConfig +from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig from olmo_core.optim import AdamWConfig, OptimGroupOverride from olmo_core.train import TrainerConfig from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback @@ -19,7 +19,7 @@ def build_model_config(common: CommonComponents) -> TransformerConfig: return TransformerConfig.olmo_7B( vocab_size=common.tokenizer.padded_vocab_size(), compile=True, - dp_config=DataParallelConfig( + dp_config=TransformerDataParallelConfig( name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 ), ) diff --git a/src/scripts/train/OLMoE-1B-7B.py b/src/scripts/train/OLMoE-1B-7B.py new file mode 100644 index 00000000..8d6dd510 --- /dev/null +++ b/src/scripts/train/OLMoE-1B-7B.py @@ -0,0 +1,122 @@ +""" +Train a 1B-7B OLMoE model (mixture of experts). +Run this script without any arguments to see usage info. 
+""" + +from olmo_core.config import DType +from olmo_core.distributed.parallel import DataParallelType +from olmo_core.internal.experiment import CommonComponents, main +from olmo_core.nn.moe import MoEActivationFn, MoEConfig, MoEMLPImplementation, MoEType +from olmo_core.nn.transformer import ( + TransformerBlockType, + TransformerConfig, + TransformerDataParallelConfig, + TransformerDataParallelWrappingStrategy, +) +from olmo_core.optim import AdamWConfig, OptimGroupOverride +from olmo_core.train import TrainerConfig +from olmo_core.train.callbacks import ( + CheckpointerCallback, + CometCallback, + MoEHandlerCallback, + WandBCallback, +) + + +def build_model_config(common: CommonComponents) -> TransformerConfig: + model_config = TransformerConfig.olmo_1B( + vocab_size=common.tokenizer.padded_vocab_size(), + n_layers=16, + n_heads=16, + compile=True, + fused_ops=False, + block_name=TransformerBlockType.moe_reordered_norm, + dp_config=TransformerDataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + wrapping_strategy=TransformerDataParallelWrappingStrategy.full, + ), + ) + model_config.block.feed_forward = None + model_config.block.feed_forward_moe = MoEConfig( + name=MoEType.dropless, + hidden_size=int(0.5 * model_config.d_model), + activation_fn=MoEActivationFn.swiglu, + mlp_implementation=MoEMLPImplementation.grouped, + num_experts=64, + top_k=8, + num_layers=model_config.n_layers, + zloss_weight=0.001, + loss_weight=0.01, + bias=False, + dtype=model_config.dtype, + ) + return model_config + + +def build_optim_config(common: CommonComponents) -> AdamWConfig: + del common + return AdamWConfig( + lr=4e-4, + weight_decay=0.1, + betas=(0.9, 0.95), + group_overrides=[ + OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) + ], + fused=True, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + return ( + TrainerConfig( + save_folder=common.save_folder, + rank_microbatch_size=2 * 4096, + save_overwrite=True, + metrics_collect_interval=10, + cancel_check_interval=1, + z_loss_multiplier=1e-5, + ) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=10_000, + ephemeral_save_interval=1000, + save_async=True, + ), + ) + .with_callback( + "moe", + MoEHandlerCallback(), + ) + .with_callback( + "comet", + CometCallback( + name=common.run_name, + workspace="ai2", + project="OLMo-core-1B", + enabled=True, + cancel_check_interval=10, + ), + ) + .with_callback( + "wandb", + WandBCallback( + name=common.run_name, + entity="ai2-llm", + project="OLMo-core-1B", + enabled=False, + cancel_check_interval=10, + ), + ) + ) + + +if __name__ == "__main__": + main( + global_batch_size=1024 * 4096, + model_config_builder=build_model_config, + optim_config_builder=build_optim_config, + trainer_config_builder=build_trainer_config, + ) diff --git a/src/test/nn/moe_test.py b/src/test/nn/moe_test.py new file mode 100644 index 00000000..4603d3e9 --- /dev/null +++ b/src/test/nn/moe_test.py @@ -0,0 +1,44 @@ +import pytest +import torch + +from olmo_core.config import DType +from olmo_core.nn.moe import MoEConfig, MoEMLPImplementation, MoEType + +from ..utils import requires_gpu, requires_megablocks + + +@requires_gpu +@requires_megablocks +@pytest.mark.parametrize("moe_type", [MoEType.default, MoEType.dropless]) +@pytest.mark.parametrize("mlp_impl", [MoEMLPImplementation.sparse, MoEMLPImplementation.grouped]) +@pytest.mark.parametrize("dtype", [pytest.param(torch.bfloat16, id="BF16")]) +def 
test_moe(moe_type, mlp_impl, dtype): + d_model = 128 + config = MoEConfig( + name=moe_type, + mlp_implementation=mlp_impl, + hidden_size=512, + num_experts=4, + dtype=DType.from_pt(dtype), + ) + moe = config.build(d_model=d_model, init_device="cuda") + + # Check num params calculation. + num_params = 0 + for p in moe.parameters(): + num_params += p.numel() + if config.num_params(d_model) != num_params: + # For debugging... + for n, p in moe.named_parameters(): + print(f"{n}: {p.shape}") + assert config.num_params(d_model) == num_params + + # Run forward pass. + x = torch.randn(2, 16, d_model, dtype=dtype, device="cuda", requires_grad=True) + output = moe(x) + assert output.shape == x.shape + loss = output.sum() + moe.get_loss() + + # Run backward pass. + loss.backward() + assert x.grad is not None diff --git a/src/test/utils.py b/src/test/utils.py index 8128b001..4b04415e 100644 --- a/src/test/utils.py +++ b/src/test/utils.py @@ -3,6 +3,7 @@ has_cuda = torch.cuda.is_available() has_flash_attn = False +has_megablocks = False try: import flash_attn # type: ignore @@ -12,6 +13,14 @@ except ModuleNotFoundError: pass +try: + import megablocks # type: ignore + + has_megablocks = True + del megablocks +except ModuleNotFoundError: + pass + GPU_MARKS = (pytest.mark.gpu, pytest.mark.skipif(not has_cuda, reason="Requires a GPU")) @@ -34,6 +43,18 @@ def requires_flash_attn(func): return func +MEGABLOCKS_MARKS = ( + pytest.mark.gpu, + pytest.mark.skipif(not has_megablocks, reason="Requires megablocks"), +) + + +def requires_megablocks(func): + for mark in MEGABLOCKS_MARKS: + func = mark(func) + return func + + INIT_DEVICES = [ pytest.param(torch.device("meta"), id="device=meta"), pytest.param(torch.device("cpu"), id="device=CPU"),
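The new `requires_megablocks` helper works the same way as `requires_flash_attn`: it stacks `MEGABLOCKS_MARKS` onto a test function, so decorating a test is equivalent to applying the two marks by hand. A short sketch (the test names and bodies below are illustrative only, not part of this diff):

import pytest

from ..utils import has_megablocks, requires_megablocks  # same relative import style as src/test/nn/moe_test.py


@requires_megablocks
def test_something_with_megablocks():
    # Skipped unless a GPU is available and 'megablocks' is importable.
    ...


# Equivalent to stacking MEGABLOCKS_MARKS by hand:
@pytest.mark.gpu
@pytest.mark.skipif(not has_megablocks, reason="Requires megablocks")
def test_equivalent_marking():
    ...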