diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f963c7d2f..ab3089822 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,6 +7,7 @@ on: pull_request: branches: - main + - Torch2 push: branches: - main @@ -108,7 +109,7 @@ jobs: timeout-minutes: 15 env: BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} - BEAKER_IMAGE: dolma-test + BEAKER_IMAGE: dolma-torch2-test BEAKER_WORKSPACE: ai2/llm-testing steps: - name: Determine current commit SHA (pull request) @@ -133,7 +134,7 @@ jobs: image: beaker: ${{ env.BEAKER_IMAGE }} context: - priority: preemptible + priority: normal resources: gpuCount: 1 constraints: @@ -141,11 +142,6 @@ jobs: - ai2/general-cirrascale - ai2/general-cirrascale-a100-80g-ib - ai2/allennlp-cirrascale - - ai2/aristo-cirrascale - - ai2/mosaic-cirrascale - - ai2/mosaic-cirrascale-a100 - - ai2/prior-cirrascale - - ai2/s2-cirrascale envVars: - name: COMMIT_SHA value: ${{ env.COMMIT_SHA }} @@ -153,6 +149,8 @@ jobs: value: ${{ secrets.GITHUB_TOKEN }} - name: CUDA_LAUNCH_BLOCKING value: "1" + - name: CUBLAS_WORKSPACE_CONFIG + value: ":16:8" - name: TOKENIZERS_PARALLELISM value: "false" command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/"] diff --git a/LOG.md b/LOG.md index febdaaa9e..84ab52b81 100644 --- a/LOG.md +++ b/LOG.md @@ -1,6 +1,47 @@ Experiment Log ============== +2023-04-03 +---------- + +We added the option to decouple the MLP and Attention computations as in the PaLM architecture. +That is, within each transformer block we compute `MLP(LN(x)) + Attention(LN(x))` instead of `MLP(LN(x + Attention(LN(x))))` (ignoring some skip connections). +This allows to increase throughput because we can fuse the separate feed-forward and attention input projections into a single linear layer. +We also experimented with [fusing the output projections](https://github.com/allenai/LLM/pull/79) into a single linear layer but that didn't help, possibly due to the overhead of concatenating the feed-forward and attention activations together. + +2023-03-28 +---------- + +We've investigated a number ways to optimize training throughput in terms of tokens per second and MFU (model flop utilization). This is a list of all of the optimizations that have worked so far, ranked by how much of speedup they gave on a 1.2b param model: + +1. Using FlashAttention via PyTorch's built-in `scaled_dot_product_attention` function. This resulted in a ~12% speedup over the default attention implementation while also reducing GPU memory utilization. + + Unfortunately ALiBi can't be used with FlashAttention at the moment, so the best option if we want to use relative positional encodings is probably RoPE (which can be used with FlashAttention). In general RoPE is slower than ALiBi but when combined with FlashAttention it's faster. Of course ALiBi + FlashAttention would be ideal. + +1. Setting embedding/vocab size to a multiple of 128. E.g. the actual vocab size is 50257, but we force the embedding size to be 50304. This resulted in an ~11% speedup. +1. Using low-precision LayerNorm when **not** using `torch.compile()`. This resulted in a speedup of ~10%, but it actually slows throughput when using a compiled model. This probably has to do with manually casting tensors to different data types, which cause more breaks in the graph. +1. Compiling the model via `torch.compile()` with the default mode. This resulted in a ~7% speedup without increasing (and in some cases decreasing) GPU memory utilization. + + The other compile modes ("reduce-overhead" and "max-autotune") were not as fast and required substantially more GPU memory. + + Compiling as a "fullgraph" also improves throughput even further except when using FSDP since FSDP forces breaks in the graph. +1. Tweaking the FSDP settings to use "PURE" mixed precision, limit all gathers, and use non-reentrant activation checkpointing resulted in a 1-2% speedup. + +Using the best compatible combination of the above settings (so everything except #3) gets us close to 60% MFU with the 1.2b model. That's really good! + +For more details, see: +- [Benchmarking the performance of PyTorch's new `compile()` and built-in FlashAttention.](https://wandb.ai/ai2-llm/petew-torch2-benchmarks/reports/PyTorch-2-0-benchmarks--VmlldzozODQyMDY5?accessToken=2fh801xe265n5xx7juphb1xnx8itvls8g7nrqsjdd4ja0xlks7kaozue94z2mez3) +- [Benchmarking the cost of using RoPE](https://wandb.ai/ai2-llm/rope-benchmarks/reports/Benchmarking-RoPE--VmlldzozODQ1MjMz) +- [Benchmarking the performance of `compile()` with FSDP](https://wandb.ai/ai2-llm/fsdp-compile-benchmarks) +- [Benchmarking low precision LayerNorm](https://api.wandb.ai/links/ai2-llm/9favfpnh) + + +2023-03-15 +---------- + +The cluster is down for maintenance, so we're just queueing up some features we want to run. We also used the LUMI downtime to build a better logging feature. When running 1000s of nodes in a cluster, it's difficult to get logs that make sense. We're sending our logs to third-party logging provider [logz.io](https://logz.io). It's basic, but it gets the job done. + + 2023-03-14 ---------- @@ -16,8 +57,3 @@ Findings: I'm not sure what that buys us, and it's one extra component in the mix, so I didn't do it that way. * Automatic restarts work. One run got killed and automatically restarted. It is great that restarts work, but somewhat worrisome that we're already sampling this behavior after less than 45 minutes of runtime on only one node. - -2023-03-15 ----------- - -The cluster is down for maintenance, so we're just queueing up some features we want to run. We also used the LUMI downtime to build a better logging feature. When running 1000s of nodes in a cluster, it's difficult to get logs that make sense. We're sending our logs to third-party logging provider [logz.io](https://logz.io). It's basic, but it gets the job done. diff --git a/Makefile b/Makefile index c299a43b7..3ac3be0f5 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # If you update this, also update BEAKER_IMAGE in .github/workflows/main.yml -IMAGE_NAME_BASE = dolma +IMAGE_NAME_BASE = dolma-torch2 # If you update this, also update BEAKER_WORKSPACE in .github/workflows/main.yml -BEAKER_WORKSPACE = "ai2/llm-testing" +BEAKER_WORKSPACE = ai2/llm-testing BEAKER_USER = $(shell beaker account whoami --format=json | jq -r '.[0].name') GANTRY_IMAGE = $(shell beaker workspace images $(BEAKER_WORKSPACE) --format=json | jq -r -c '.[] | select( .name == "$(IMAGE_NAME_BASE)-gantry" ) | .fullName') @@ -58,7 +58,7 @@ show-beaker-workspace : gantry-test : gantry run \ --workspace "$(BEAKER_WORKSPACE)" \ - --priority "preemptible" \ + --priority "normal" \ --beaker-image "$(GANTRY_IMAGE)" \ --gpus 1 \ --description "Test run" \ diff --git a/README.md b/README.md index 461bec08b..33e44b371 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ gantry run \ --nfs \ --priority preemptible \ --gpus 8 \ - --beaker-image dolma-gantry \ + --beaker-image dolma-torch2-gantry \ --cluster 'ai2/*-cirrascale' \ --allow-dirty \ -- composer scripts/train.py configs/1.2b-c4.yaml @@ -36,7 +36,7 @@ Train the 70B model on c4 with gantry across multiple nodes: gantry run \ --workspace ai2/llm-testing \ --priority "high" \ - --beaker-image dolma-gantry \ + --beaker-image dolma-torch2-gantry \ --cluster ai2/general-cirrascale-a100-80g-ib \ --gpus 8 \ --nfs \ diff --git a/configs/1.2b-c4.yaml b/configs/1.2b-c4.yaml index 21c44dbd9..f93f4cb70 100644 --- a/configs/1.2b-c4.yaml +++ b/configs/1.2b-c4.yaml @@ -11,17 +11,24 @@ model: alibi_bias_max: 8.0 attention_dropout: 0.0 attention_layer_norm: true + layer_norm_type: default # if not compiling, use 'low_precision' + activation_type: swiglu residual_dropout: 0.0 embedding_dropout: 0.0 max_sequence_length: 1024 vocab_size: 50257 + embedding_size: 50304 eos_token_id: 50256 pad_token_id: 50256 - init_device: meta + init_device: null init_std: 0.02 +compile: + mode: default + fullgraph: null + optimizer: - name: decoupled_adamw + name: decoupled_lionw learning_rate: 2.0e-4 weight_decay: 1.2e-4 betas: @@ -71,13 +78,16 @@ device_eval_batch_size: null n_gpus: null -precision: null +precision: amp_bf16 fsdp_config: sharding_strategy: FULL_SHARD - mixed_precision: DEFAULT + mixed_precision: PURE activation_checkpointing: false activation_cpu_offload: false + activation_checkpointing_reentrant: false + limit_all_gathers: true + use_orig_params: true # needed to work with compile verbose: false speed_monitor: diff --git a/configs/300m-c4.yaml b/configs/300m-c4.yaml index 6f7934dba..be5c8f282 100644 --- a/configs/300m-c4.yaml +++ b/configs/300m-c4.yaml @@ -12,18 +12,24 @@ model: flash_attention: false attention_dropout: 0.0 attention_layer_norm: false + layer_norm_type: default # if not compiling, use 'low_precision' residual_dropout: 0.0 embedding_dropout: 0.0 max_sequence_length: 1024 include_bias: true vocab_size: 50257 + embedding_size: 50304 eos_token_id: 50256 pad_token_id: 50256 init_device: null init_std: 0.02 +compile: + mode: default + fullgraph: null + optimizer: - name: decoupled_adamw + name: decoupled_lionw learning_rate: 3.0e-4 weight_decay: 1.2e-4 betas: @@ -72,13 +78,16 @@ device_eval_batch_size: null n_gpus: null -precision: null +precision: amp_bf16 fsdp_config: sharding_strategy: FULL_SHARD - mixed_precision: DEFAULT + mixed_precision: PURE activation_checkpointing: false activation_cpu_offload: false + activation_checkpointing_reentrant: false + limit_all_gathers: true + use_orig_params: true # needed to work with compile verbose: false wandb: diff --git a/configs/70b-c4.yaml b/configs/70b-c4.yaml index a8aaacdb0..27a79e6fd 100644 --- a/configs/70b-c4.yaml +++ b/configs/70b-c4.yaml @@ -12,17 +12,25 @@ model: flash_attention: false attention_dropout: 0.0 # has to be 0 if using flash attn attention_layer_norm: true + block_type: parallel + layer_norm_type: default # if not compiling, use 'low_precision' + activation_type: swiglu residual_dropout: 0.0 embedding_dropout: 0.0 max_sequence_length: 2048 vocab_size: 50257 + embedding_size: 50304 eos_token_id: 50256 pad_token_id: 50256 init_device: meta init_std: 0.02 +compile: + mode: default + fullgraph: null + optimizer: - name: decoupled_adamw + name: decoupled_lionw learning_rate: 8.0e-5 weight_decay: 1.2e-4 betas: @@ -76,9 +84,12 @@ precision: amp_bf16 fsdp_config: sharding_strategy: FULL_SHARD - mixed_precision: DEFAULT # could be PURE with flash attn + mixed_precision: PURE activation_checkpointing: true activation_cpu_offload: false + activation_checkpointing_reentrant: false + limit_all_gathers: true + use_orig_params: true # needed to work with compile verbose: false speed_monitor: diff --git a/configs/tiny.yaml b/configs/tiny.yaml index b65f4b4b6..9bb0409f5 100644 --- a/configs/tiny.yaml +++ b/configs/tiny.yaml @@ -23,7 +23,7 @@ model: init_std: 0.02 optimizer: - name: decoupled_adamw + name: decoupled_lionw learning_rate: 3.0e-4 weight_decay: 1.2e-4 betas: @@ -72,7 +72,7 @@ device_eval_batch_size: null n_gpus: null -precision: null +precision: amp_bf16 fsdp_config: null diff --git a/docker/Dockerfile.base b/docker/Dockerfile.base index 83411da3f..c32eba965 100644 --- a/docker/Dockerfile.base +++ b/docker/Dockerfile.base @@ -1,15 +1,14 @@ # Defines a CUDA-enabled Docker image suitable for installing all dependencies # to this project. -FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10 +FROM ghcr.io/allenai/pytorch:2.0.0-cuda11.8-python3.10 +# Install flash attn (and triton dependency) from our pre-built wheel. # We need cuda dev for the old version of triton. # NOTE: once we're able to upgrade triton to >=2.0, we can remove this. -RUN /opt/conda/bin/conda install -c nvidia cuda-libraries-dev - -# Install flash attn (and triton dependency) from our pre-built wheel. -RUN /opt/conda/bin/pip install --no-cache-dir \ - triton==2.0.0.dev20221202 \ - https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu117torch1.13.1-cp310-cp310-linux_x86_64.whl +# RUN /opt/conda/bin/conda install -c nvidia cuda-libraries-dev +# RUN /opt/conda/bin/pip install --no-cache-dir \ +# triton==2.0.0.dev20221202 \ +# https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu118torch2.0.0-cp310-cp310-linux_x86_64.whl ENV CUDA_HOME=/opt/conda diff --git a/docker/Dockerfile.gantry b/docker/Dockerfile.gantry index 1387ebe66..14a902870 100644 --- a/docker/Dockerfile.gantry +++ b/docker/Dockerfile.gantry @@ -4,7 +4,7 @@ # To build and push the image to Beaker, run 'make gantry-image'. # To test the image after pushing to Beaker, run 'make gantry-test'. -FROM dolma-base +FROM dolma-torch2-base WORKDIR /stage diff --git a/docker/Dockerfile.lumi b/docker/Dockerfile.lumi index 3c66860af..f01991249 100644 --- a/docker/Dockerfile.lumi +++ b/docker/Dockerfile.lumi @@ -1,16 +1,39 @@ -FROM rocm/dev-ubuntu-22.04:5.4-complete +FROM ubuntu:latest ENV DEBIAN_FRONTEND=noninteractive ENV LC_ALL=C.UTF-8 ENV LANG=C.UTF-8 +# Install various softwares RUN apt-get update RUN apt-get upgrade -y -RUN apt-get install -y python-is-python3 git autoconf python3-dev git vim libtool openjdk-8-jdk-headless xvfb fish build-essential wget parallel s3cmd awscli rocm-libs rccl +RUN apt-get install -y \ + python-is-python3 \ + python3-dev \ + libpython3-all-dev \ + python-dev-is-python3 \ + python3-pip \ + build-essential \ + git \ + autoconf \ + libtool \ + llvm \ + vim \ + fish \ + wget \ + parallel \ + s3cmd \ + awscli \ + htop \ + wget \ + fish -# Fix for Java trying to find assistive techs in headless java -# https://askubuntu.com/questions/695560/assistive-technology-not-found-awterror -RUN sed -i -e '/^assistive_technologies=/s/^/#/' /etc/java-8-openjdk/accessibility.properties +# Install ROCm +RUN wget https://repo.radeon.com/amdgpu-install/5.4.3/ubuntu/jammy/amdgpu-install_5.4.50403-1_all.deb && \ + apt-get install -y ./amdgpu-install_5.4.50403-1_all.deb && \ + amdgpu-install -y --accept-eula --usecase=rocm --no-dkms && \ + rm ./amdgpu-install_5.4.50403-1_all.deb && \ + apt-get install -y rccl rocm-libs # Install MPICH ENV MPICH_VERSION="3.1.4" @@ -34,8 +57,7 @@ ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH # Install torch RUN pip install --upgrade pip -RUN pip install --no-cache-dir "torch<2.0" torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2 -#RUN pip install --pre --no-cache-dir "torch<2.0" torchvision torchaudio torchtext --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.3 +RUN pip install --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2 # Install DeepSpeed RUN pip install --no-cache-dir mpi4py @@ -49,9 +71,10 @@ RUN cd /opt && \ COPY requirements.txt requirements.txt RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir py-spy -RUN pip install wandb --upgrade +RUN pip install --no-cache-dir wandb --upgrade # Cleanup +RUN apt-get autoremove RUN rm -rf /opt/mpich-3.1.4 /opt/aws-ofi-rccl /opt/DeepSpeed RUN apt-get clean RUN pip cache purge diff --git a/docker/Dockerfile.test b/docker/Dockerfile.test index eb301a845..35614df8b 100644 --- a/docker/Dockerfile.test +++ b/docker/Dockerfile.test @@ -4,7 +4,7 @@ # # To build and push the image to Beaker, run 'make test-image'. -FROM dolma-base +FROM dolma-torch2-base COPY scripts/test_entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh diff --git a/dolma/composer.py b/dolma/composer.py index 89b8f50b6..9f737551e 100644 --- a/dolma/composer.py +++ b/dolma/composer.py @@ -1,40 +1,96 @@ import logging import math import warnings -from typing import Any, Dict, Optional, Tuple, Union +from fnmatch import fnmatch +from typing import Any, Dict, Optional, Set, Tuple, TypedDict, Union import torch +import torch.nn as nn import torch.nn.functional as F from composer.loggers import ConsoleLogger from composer.loggers.logger import format_log_data_value from composer.models import ComposerModel from composer.utils import dist +from torch.utils.data import DataLoader from torchmetrics import Metric from .aliases import BatchDict -from .config import ModelConfig, SchedulerConfig, SchedulerType, TrainConfig +from .config import ( + ModelConfig, + OptimizerType, + SchedulerConfig, + SchedulerType, + TrainConfig, +) +from .data import DataCollator, MemMapDataset from .exceptions import DolmaConfigurationError -from .model import DolmaGPT, DolmaGPTOutput +from .model import Dolma, LayerNormBase +from .optim import DecoupledLionW log = logging.getLogger(__name__) -__all__ = ["ComposerDolmaGPT", "DolmaConsoleLogger", "build_scheduler", "build_algorithm"] +__all__ = [ + "TrainBatchPerplexity", + "ComposerDolmaLM", + "DolmaConsoleLogger", + "build_dataloader", + "build_optimizer", + "build_scheduler", + "build_algorithm", +] -class ComposerDolmaGPT(ComposerModel): - def __init__(self, config: ModelConfig): +class TrainBatchOutput(TypedDict, total=True): + logits: torch.Tensor + """ + The (shifted) logits. + """ + + labels: torch.Tensor + """ + The (shifted) label token IDs. + """ + + loss: torch.Tensor + """ + The cross-entropy loss. + """ + + +class TrainBatchPerplexity(Metric): + """ + A metric for tracking training perplexity on a per-batch basis. + We use this as a training metric instead of composer's built-in + :class:`LanguageCrossEntropy` to avoid recomputing the loss. + """ + + def __init__(self) -> None: + super().__init__(sync_on_compute=False) + self.loss: Optional[torch.Tensor] + + def update(self, loss: torch.Tensor): + self.loss = loss + + def compute(self) -> torch.Tensor: + assert self.loss is not None + return torch.exp(self.loss) + + +class ComposerDolmaLM(ComposerModel): + def __init__(self, model_or_config: Union[Dolma, ModelConfig]): super().__init__() - self.model = DolmaGPT(config) + self.model = Dolma(model_or_config) if isinstance(model_or_config, ModelConfig) else model_or_config + self.config = self.model.config + self.num_fwd_flops = self.model.num_fwd_flops - from composer.metrics.nlp import LanguageCrossEntropy, Perplexity + from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity - self.train_metrics = { - "LanguageCrossEntropy": LanguageCrossEntropy(config.vocab_size), - "Perplexity": Perplexity(), + self.train_metrics: Dict[str, Metric] = { + "Perplexity": TrainBatchPerplexity(), } - self.eval_metrics = { - "LanguageCrossEntropy": LanguageCrossEntropy(config.vocab_size), - "Perplexity": Perplexity(), + self.eval_metrics: Dict[str, Metric] = { + "Perplexity": LanguagePerplexity(), + "CrossEntropy": LanguageCrossEntropy(), } def get_labels(self, batch: BatchDict) -> torch.Tensor: @@ -44,28 +100,29 @@ def get_labels(self, batch: BatchDict) -> torch.Tensor: labels = labels.masked_fill(attention_mask == 0.0, -100) return labels[..., 1:].contiguous() - def forward(self, batch: BatchDict) -> DolmaGPTOutput: - return self.model(**batch) - - def loss(self, outputs: DolmaGPTOutput, batch: BatchDict) -> torch.Tensor: + def forward(self, batch: BatchDict) -> TrainBatchOutput: + logits = self.model(**batch).logits[..., :-1, :].contiguous() labels = self.get_labels(batch) - shift_logits = outputs.logits[..., :-1, :].contiguous() - return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1), ignore_index=-100) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100) + return {"logits": logits, "labels": labels, "loss": loss} + + def loss(self, outputs: TrainBatchOutput, batch: BatchDict) -> torch.Tensor: + del batch + return outputs["loss"] - def eval_forward(self, batch: BatchDict, outputs: Optional[DolmaGPTOutput] = None) -> DolmaGPTOutput: + def eval_forward(self, batch: BatchDict, outputs: Optional[TrainBatchOutput] = None) -> TrainBatchOutput: return outputs if outputs is not None else self.forward(batch) - def get_metrics(self, is_train: bool = False) -> Dict[str, "Metric"]: + def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]: return self.train_metrics if is_train else self.eval_metrics - def update_metric(self, batch: BatchDict, outputs: DolmaGPTOutput, metric: "Metric") -> None: - labels = self.get_labels(batch) - shift_logits = outputs.logits[..., :-1, :].contiguous() - metric.update(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) - - @property - def num_fwd_flops(self): - return self.model.num_fwd_flops + def update_metric(self, batch: BatchDict, outputs: TrainBatchOutput, metric: Metric) -> None: + del batch + if isinstance(metric, TrainBatchPerplexity): + metric.update(outputs["loss"].detach()) + else: + logits, labels = outputs["logits"], outputs["labels"] + metric.update(logits.view(-1, logits.size(-1)), labels.view(-1)) def flops_per_batch(self, batch: BatchDict): # Note: this computation does not take into account padding, and assumes @@ -75,6 +132,16 @@ def flops_per_batch(self, batch: BatchDict): class DolmaConsoleLogger(ConsoleLogger): + metrics_to_log: Set[str] = {"loss/train/total", "trainer/global_step", "metrics/*"} + + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + del step + # Lazy logging of metrics. + # Stores all metrics logged until they are cleared with a log_to_console call + self.logged_metrics.update( + {k: v for k, v in metrics.items() if any(fnmatch(k, pattern) for pattern in self.metrics_to_log)} + ) + def _log_hparams_to_console(self): if dist.get_local_rank() == 0: log_str = "Config:" @@ -87,6 +154,103 @@ def _log_to_console(self, log_str: str): log.info(log_str) +def build_dataloader(config: TrainConfig, batch_size: int) -> DataLoader: + from composer.utils.dist import get_sampler + + collator = DataCollator.from_train_config(config) + dataset = MemMapDataset.from_train_config(config) + sampler = get_sampler(dataset, shuffle=True, drop_last=config.data.drop_last) + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collator, + num_workers=config.data.num_workers, + sampler=sampler, + pin_memory=config.data.pin_memory, + prefetch_factor=config.data.prefetch_factor, + persistent_workers=config.data.persistent_workers, + timeout=config.data.timeout, + ) + + +def build_optimizer( + model, + name: OptimizerType = OptimizerType.decoupled_lionw, + learning_rate: Optional[float] = None, + weight_decay: float = 0.0, + betas: Tuple[float, float] = (0.9, 0.95), + eps: float = 1e-8, +) -> torch.optim.Optimizer: + """ + Get a suitable optimizer for training/fine-tuning. + + :param learning_rate: The learning rate. If not specified, a default learning + rate will calculated according to the equation from the Scaling Laws paper + `0.003239 - 0.0001395 * math.log(N)`, + where `N` is the number of trainable parameters excluding embeddings. + :param weight_decay: The weight decay coefficient. This does not apply to + biases and layernorm/embedding weights, which will have a weight decay + coefficient of 0. + :param kwargs: Other keyword arguments passed to the optimizer. + """ + # Separate out all parameters to those that will and won't experience regularizing weight decay. + decay = set() + no_decay = set() + all_params = {} + num_trainable_non_embedding_weights = 0 + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + # NOTE: because named_modules and named_parameters are recursive + # we will see the same tensors p many many times, but doing it this way + # allows us to know which parent module any tensor p belongs to... + if not p.requires_grad: + continue + + fpn = f"{mn}.{pn}" if mn else pn + all_params[fpn] = p + + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, nn.Linear): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, (LayerNormBase, nn.LayerNorm, nn.Embedding)): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + if fpn not in {"transformer.wte.weight", "transformer.wpe.weight"}: + num_trainable_non_embedding_weights += p.numel() + + # Validate that we've considered every parameter + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!" + assert ( + len(all_params.keys() - union_params) == 0 + ), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!" + + # Create the pytorch optimizer groups. + optim_groups = [ + {"params": [all_params[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [all_params[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + if learning_rate is None: + learning_rate = 0.003239 - 0.0001395 * math.log(num_trainable_non_embedding_weights) + + if name == OptimizerType.decoupled_lionw: + return DecoupledLionW(optim_groups, lr=learning_rate, betas=betas) + elif name == OptimizerType.decoupled_adamw: + from composer.optim import DecoupledAdamW + + return DecoupledAdamW(optim_groups, lr=learning_rate, betas=betas, eps=eps) + elif name == OptimizerType.adamw: + return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, eps=eps) + else: + raise NotImplementedError(f"Not sure how to build optimizer '{name}'") + + def build_scheduler(cfg: SchedulerConfig): from composer.optim.scheduler import ( ConstantWithWarmupScheduler, @@ -113,8 +277,6 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]): return algorithms.FusedLayerNorm(**kwargs) elif name == "gated_linear_units": return algorithms.GatedLinearUnits(**kwargs) - elif name == "low_precision_layernorm": - return algorithms.LowPrecisionLayerNorm(**kwargs) else: raise NotImplementedError(f"Not sure how to build algorithm '{name}'") diff --git a/dolma/config.py b/dolma/config.py index ddebb3a2d..1aabacc23 100644 --- a/dolma/config.py +++ b/dolma/config.py @@ -25,6 +25,10 @@ from .exceptions import DolmaConfigurationError __all__ = [ + "ActivationType", + "BlockType", + "CompilerConfig", + "LayerNormType", "ModelConfig", "OptimizerType", "OptimizerConfig", @@ -114,6 +118,40 @@ def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]: return out +class LayerNormType(StrEnum): + default = "default" + """ + The default LayerNorm implementation, equivalent to PyTorch's built-in version. + """ + + low_precision = "low_precision" + """ + A low-precision version of the default LayerNorm. + """ + + rms = "rms" + """ + An RMSNorm implementation. When using ``torch.compile`` this is + probably the fastest implementation. + """ + + low_precision_rms = "low_precision_rms" + """ + A low-precision version of RMSNorm. + """ + + +class ActivationType(StrEnum): + gelu = "gelu" + relu = "relu" + swiglu = "swiglu" + + +class BlockType(StrEnum): + sequential = "sequential" + parallel = "parallel" + + @dataclass class ModelConfig(BaseConfig): """ @@ -142,9 +180,19 @@ class ModelConfig(BaseConfig): The ratio of the inner MLP dimensionality to ``d_model``. """ + activation_type: ActivationType = ActivationType.swiglu + """ + The activation function to use within the MLP layers. + """ + + block_type: BlockType = BlockType.sequential + """ + The transformer block implementation. + """ + alibi: bool = False """ - If ``True``, use ALiBi embeddings. + If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``. """ alibi_bias_max: float = 8.0 @@ -152,11 +200,21 @@ class ModelConfig(BaseConfig): Maximum absolute value of ALiBi bias. """ + rope: bool = False + """ + Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. + """ + flash_attention: bool = False """ If ``True``, use ``FlashAttention``. """ + memory_efficient_attention: bool = False + """ + If ``True``, enable memory-efficient attention. + """ + attention_dropout: float = 0.1 """ The dropout probability within the attention modules. @@ -178,6 +236,11 @@ class ModelConfig(BaseConfig): The dropout probability for embeddings. """ + layer_norm_type: LayerNormType = LayerNormType.default + """ + The layernorm implementation to use. + """ + max_sequence_length: int = 1024 """ The maximum input sequence length supported by the model. @@ -195,6 +258,14 @@ class ModelConfig(BaseConfig): Vocabulary size of the model. """ + embedding_size: Optional[int] = 50304 + """ + The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default + to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the + next multiple of 128 that's greater than ``vocab_size`` can improve throughput + substantially. + """ + eos_token_id: int = 50256 """ The ID of the end-of-sentence special token. @@ -215,6 +286,12 @@ class ModelConfig(BaseConfig): Standard deviation used when initializing parameters. """ + precision: Optional[str] = None + """ + Precision used to train/evaluate with. You shouldn't set this directly. + See :data:`TrainConfig.precision` instead. + """ + @property def device(self) -> Optional[str]: if self.init_device == "meta" or self.init_device is None: @@ -266,7 +343,7 @@ class DataConfig(BaseConfig): num_workers: int = 0 drop_last: bool = True pin_memory: bool = True - prefetch_factor: int = 2 + prefetch_factor: Optional[int] = 2 persistent_workers: bool = True timeout: int = 0 @@ -299,6 +376,28 @@ class SpeedMonitorConfig(BaseConfig): gpu_flops_available: Optional[Union[float, int]] = None +@dataclass +class CompilerConfig(BaseConfig): + mode: Optional[str] = None + """ + The mode to compile the model in. At the moment this can be "default", + "reduce-overhead" (useful for smaller models/batches), or "max-autotune" + (the fastest for larger models, but takes a long time to compile). + """ + + fullgraph: Optional[bool] = None + """ + Whether it is OK to break model into several subgraphs when compiling. + + If ``None``, ``fullgraph`` will default to ``True`` unless used during FSDP distributed training. + """ + + backend: str = "inductor" + """ + The backend to use. + """ + + @dataclass class TrainConfig(BaseConfig): """ @@ -311,7 +410,7 @@ class TrainConfig(BaseConfig): model: ModelConfig = field(default_factory=ModelConfig) optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) - algorithms: Optional[Dict[str, Dict[str, Any]]] = None + algorithms: Optional[Dict[str, Optional[Dict[str, Any]]]] = None data: DataConfig = field(default_factory=DataConfig) tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig) save_folder: str = "./" @@ -332,6 +431,10 @@ class TrainConfig(BaseConfig): wandb: Optional[WandbConfig] = None speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig) console_log_interval: Union[str, int] = "1ba" + compile: Optional[CompilerConfig] = None + """ + Settings for compiling the model with ``torch.compile()``. + """ @property def device(self) -> Optional[str]: diff --git a/dolma/data/__init__.py b/dolma/data/__init__.py index f00334920..782c9f332 100644 --- a/dolma/data/__init__.py +++ b/dolma/data/__init__.py @@ -1,5 +1,4 @@ from .collator import DataCollator -from .dataloader import build_dataloader from .memmap_dataset import MemMapDataset -__all__ = ["MemMapDataset", "DataCollator", "build_dataloader"] +__all__ = ["MemMapDataset", "DataCollator"] diff --git a/dolma/data/dataloader.py b/dolma/data/dataloader.py deleted file mode 100644 index 9a2a690d4..000000000 --- a/dolma/data/dataloader.py +++ /dev/null @@ -1,26 +0,0 @@ -from torch.utils.data import DataLoader - -from ..config import TrainConfig -from .collator import DataCollator -from .memmap_dataset import MemMapDataset - -__all__ = ["build_dataloader"] - - -def build_dataloader(config: TrainConfig, batch_size: int) -> DataLoader: - from composer.utils.dist import get_sampler - - collator = DataCollator.from_train_config(config) - dataset = MemMapDataset.from_train_config(config) - sampler = get_sampler(dataset, shuffle=True, drop_last=config.data.drop_last) - return DataLoader( - dataset, - batch_size=batch_size, - collate_fn=collator, - num_workers=config.data.num_workers, - sampler=sampler, - pin_memory=config.data.pin_memory, - prefetch_factor=config.data.prefetch_factor, - persistent_workers=config.data.persistent_workers, - timeout=config.data.timeout, - ) diff --git a/dolma/model.py b/dolma/model.py index 05519a01f..98732026b 100644 --- a/dolma/model.py +++ b/dolma/model.py @@ -4,205 +4,408 @@ [minGPT](https://github.com/karpathy/minGPT.git) """ +from __future__ import annotations + import math from abc import abstractmethod from typing import NamedTuple, Optional, cast import torch +import torch.backends.cuda import torch.nn as nn import torch.nn.functional as F -from einops import rearrange +from torch import einsum + +from .config import ActivationType, BlockType, LayerNormType, ModelConfig +from .exceptions import DolmaConfigurationError + +__all__ = [ + "LayerNormBase", + "LayerNorm", + "RMSLayerNorm", + "RotaryEmbedding", + "Activation", + "GELU", + "ReLU", + "SwiGLU", + "DolmaBlock", + "DolmaSequentialBlock", + "DolmaParallelBlock", + "Dolma", +] + + +class LayerNormBase(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @classmethod + def build(cls, config: ModelConfig) -> LayerNormBase: + if config.layer_norm_type == LayerNormType.default: + return LayerNorm(config, low_precision=False) + elif config.layer_norm_type == LayerNormType.low_precision: + return LayerNorm(config, low_precision=True) + elif config.layer_norm_type == LayerNormType.rms: + return RMSLayerNorm(config, low_precision=False) + elif config.layer_norm_type == LayerNormType.low_precision_rms: + return RMSLayerNorm(config, low_precision=True) + else: + raise NotImplementedError(f"Not sure how to handle '{config.layer_norm_type}' LayerNorm type") + + def _cast_if_autocast_enabled(self, tensor: torch.Tensor) -> torch.Tensor: + if torch.is_autocast_enabled(): + if tensor.device.type == "cuda": + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == "cpu": + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor -from .config import ModelConfig -__all__ = ["TorchAttention", "GPTMLP", "GPTBlock", "DolmaGPT"] +class LayerNorm(LayerNormBase): + """ + The default :class:`LayerNorm` implementation which can optionally run in low precision. + """ + + def __init__(self, config: ModelConfig, low_precision: bool = False): + super().__init__(config) + self.normalized_shape = (config.d_model,) + self.eps = 1e-05 + self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device)) + self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device)) + self.low_precision = low_precision + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.low_precision: + module_device = x.device + downcast_x = self._cast_if_autocast_enabled(x) + downcast_weight = ( + self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + ) + downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) + else: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + + +class RMSLayerNorm(LayerNorm): + """ + RMS layer norm, a simplified :class:`LayerNorm` implementation that can optionally run + in low-precision. + """ + def __init__(self, config: ModelConfig, low_precision: bool = False): + super().__init__(config) + self.eps = 1e-08 + self.weight = nn.Parameter(torch.ones(self.config.d_model)) + if self.config.include_bias: + self.bias = nn.Parameter(torch.zeros(self.config.d_model)) + else: + self.register_parameter("bias", None) + self.low_precision = low_precision + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.low_precision: + module_device = x.device + downcast_x = self._cast_if_autocast_enabled(x) + downcast_weight = self._cast_if_autocast_enabled(self.weight) + downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.config.include_bias else None + with torch.autocast(enabled=False, device_type=module_device.type): + return self.rms_norm(downcast_x, downcast_weight, downcast_bias) + else: + return self.rms_norm(x, self.weight, self.bias if self.config.include_bias else None) + + def rms_norm(self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: + norm_x = x.norm(2, dim=-1, keepdim=True) + + rms_x = norm_x * self.config.d_model ** (-1.0 / 2) + x_normed = x / (rms_x + self.eps) + + if bias is not None: + return weight * x_normed + self.bias + else: + return weight * x_normed + + +class RotaryEmbedding(nn.Module): + """ + [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864). + """ -class DolmaAttentionBase(nn.Module): def __init__(self, config: ModelConfig): super().__init__() - assert config.d_model % config.n_heads == 0 - self.n_heads = config.n_heads - self.d_model = config.d_model + dim = config.d_model // config.n_heads + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=config.init_device).float() / dim)) + self.register_buffer("inv_freq", inv_freq) - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear( - config.d_model, 3 * config.d_model, bias=config.include_bias, device=config.init_device - ) - # for param init fn - self.c_attn._fused = (0, (self.d_model, 2 * self.d_model)) # type: ignore + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) # type: ignore + freqs = einsum("i , j -> i j", seq, self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) - # output projection - self.c_proj = nn.Linear( - config.d_model, config.d_model, bias=config.include_bias, device=config.init_device - ) - # for param init fn - self.c_proj._is_residual = True # type: ignore - # regularization - self.attn_dropout = nn.Dropout(config.attention_dropout) - self.resid_dropout = nn.Dropout(config.residual_dropout) +def rotate_half(x: torch.Tensor) -> torch.Tensor: + B, nh, T, hs = x.size() + x = x.view(B, nh, T, 2, hs // 2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) - # optional layer norm for keys and queries. - self.k_ln: Optional[nn.LayerNorm] = None - self.q_ln: Optional[nn.LayerNorm] = None - if config.attention_layer_norm: - self.k_ln = nn.LayerNorm(self.d_model, device=config.init_device) - self.q_ln = nn.LayerNorm(self.d_model, device=config.init_device) - @abstractmethod - def forward( - self, - x: torch.FloatTensor, - attention_bias: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - raise NotImplementedError +def apply_rotary_pos_emb(pos: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + out = (t * pos.cos()) + (rotate_half(t) * pos.sin()) + return out.to(t.dtype) -class TorchAttention(DolmaAttentionBase): +class Activation(nn.Module): def __init__(self, config: ModelConfig): - super().__init__(config) - - def forward( - self, - x: torch.FloatTensor, - attention_bias: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - """ - :param x: A tensor of shape `(batch_size, seq_len, d_model)`. - :param attention_bias: A tensor of shape `(batch_size, n_heads, seq_len, seq_len)` - or an equivalently broadcastable shape. This is used to introduce causal or other biases - and it is simply added to the attention scores before the softmax. - """ - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (d_model) + super().__init__() + self.config = config - # Calculate query, key, values for all heads in batch. - # shape (all): (B, T, C) - q, k, v = self.c_attn(x).split(self.d_model, dim=2) + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError - # Optionally apply layer norm to keys and queries. - if self.k_ln is not None and self.q_ln is not None: - dtype = k.dtype - k = self.k_ln(k).to(dtype=dtype) - q = self.q_ln(q).to(dtype=dtype) + @property + @abstractmethod + def output_multiplier(self) -> float: + raise NotImplementedError - # Move head forward to be next to the batch dim. - # shape (all): (B, nh, T, hs) - k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) - q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) - v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) + @classmethod + def build(cls, config: ModelConfig) -> Activation: + if config.activation_type == ActivationType.gelu: + return cast(Activation, GELU(approximate="none")) + elif config.activation_type == ActivationType.relu: + return cast(Activation, ReLU(inplace=False)) + elif config.activation_type == ActivationType.swiglu: + return SwiGLU(config) + else: + raise NotImplementedError(f"not sure how to handle activation type '{config.activation_type}'") - # Self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - # Apply bias. - if attention_bias is not None: - att = att + attention_bias[:, :, :T, :T] +class GELU(nn.GELU): + @property + def output_multiplier(self) -> float: + return 1.0 - # Apply softmax and dropout. - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - # Get head outputs. - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) +class ReLU(nn.ReLU): + @property + def output_multiplier(self) -> float: + return 1.0 - # Re-assemble all head outputs side by side. - y = y.transpose(1, 2).contiguous().view(B, T, C) - # Apply output projection. - y = self.resid_dropout(self.c_proj(y)) +class SwiGLU(Activation): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x - return y + @property + def output_multiplier(self) -> float: + return 0.5 -class FlashAttention(DolmaAttentionBase): +class DolmaBlock(nn.Module): """ - Triton implementation of FlashAttention. + A base class for transformer block implementations. """ def __init__(self, config: ModelConfig): - from flash_attn import flash_attn_triton # type: ignore + super().__init__() + self.config = config + assert config.d_model % config.n_heads == 0 - super().__init__(config) + # Dropout. + self.dropout = nn.Dropout(config.residual_dropout) + + # Layer norms. + self.norm = LayerNorm.build(config) + self.k_norm: Optional[LayerNormBase] = None + self.q_norm: Optional[LayerNormBase] = None + if config.attention_layer_norm: + self.k_norm = LayerNormBase.build(config) + self.q_norm = LayerNormBase.build(config) - assert self.d_model / self.n_heads in {64, 128}, "FlashAttention requires head dim of 64 or 128 for now" - assert config.attention_dropout == 0, "FlashAttention does not support attention dropout for now" - self.flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * config.mlp_ratio * config.d_model) % 1 == 0 - def forward( - self, x: torch.FloatTensor, attention_bias: Optional[torch.FloatTensor] = None - ) -> torch.FloatTensor: - """ - :param x: A tensor of shape `(batch_size, seq_len, d_model)`. - :param attention_bias: A tensor of shape `(batch_size, n_heads, seq_len, seq_len)` - or an equivalently broadcastable shape. This is used to introduce causal or other biases - and it is simply added to the attention scores before the softmax. - """ - # Calculate query, key, values for all heads in batch. - # shape: (batch_size, seq_length, d_model * 3) - qkv = self.c_attn(x) + # Attention output projection. + self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward output projection. + self.ff_out = nn.Linear( + int(self.act.output_multiplier * config.mlp_ratio * config.d_model), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.ff_out._is_residual = True # type: ignore + + # Rotary embeddings. + if self.config.rope: + self.rotary_emb = RotaryEmbedding(config) + self.register_buffer( + "pos_emb", self.rotary_emb(config.max_sequence_length, device=config.init_device), persistent=False + ) + + def get_rotary_embedding(self, seq_len: int, device: Optional[torch.device]) -> torch.Tensor: + if self.pos_emb is not None and self.pos_emb.shape[-2] >= seq_len: # type: ignore + return self.pos_emb[:seq_len] # type: ignore + + pos_emb = self.rotary_emb(seq_len, device=device) + self.register_buffer("pos_emb", pos_emb, persistent=False) + return pos_emb + + def attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_bias: Optional[torch.FloatTensor] = None + ) -> torch.Tensor: + B, T, C = q.size() # batch size, sequence length, d_model + dtype = k.dtype # Optionally apply layer norm to keys and queries. - if self.q_ln is not None and self.k_ln is not None: - # Applying layernorm to qk - dtype = qkv.dtype - q, k, v = qkv.split(self.d_model, dim=-1) - q = self.q_ln(q).to(dtype=dtype) - k = self.k_ln(k).to(dtype=dtype) - qkv = torch.cat([q, k, v], dim=-1) + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q).to(dtype=dtype) + k = self.k_norm(k).to(dtype=dtype) - # Apply inner attention function. - qkv = rearrange(qkv, "b s (t h d) -> b s t h d", t=3, h=self.n_heads) - y = self.flash_attn_qkvpacked_func(qkv, attention_bias) + # Move head forward to be next to the batch dim. + # shape (all): (B, nh, T, hs) + q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + k = k.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + v = v.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + + if self.config.rope: + # Apply rotary embeddings. + positions = self.get_rotary_embedding(T, q.device) + q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) + + # Get the attention scores. + # shape: (B, nh, T, hs) + att = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None if attention_bias is None else attention_bias.to(dtype=dtype), + dropout_p=0.0 if not self.training else self.config.attention_dropout, + is_causal=attention_bias is None, + ) - # Re-assemble all head outputs side by side. - y = rearrange(y, "b s h d -> b s (h d)") + # Re-assemble all head outputs side-by-side. + att = att.transpose(1, 2).contiguous().view(B, T, C) # Apply output projection. - y = self.resid_dropout(self.c_proj(y)) + return self.attn_out(att) - return y + @abstractmethod + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + @classmethod + def build(cls, config: ModelConfig) -> DolmaBlock: + if config.block_type == BlockType.sequential: + return DolmaSequentialBlock(config) + elif config.block_type == BlockType.parallel: + return DolmaParallelBlock(config) + else: + raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'") + + +class DolmaSequentialBlock(DolmaBlock): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ -class GPTMLP(nn.Module): def __init__(self, config: ModelConfig): - super().__init__() - self.c_fc = nn.Linear( - config.d_model, config.mlp_ratio * config.d_model, bias=config.include_bias, device=config.init_device + super().__init__(config) + # Attention input projection. Projects x -> (q, k, v) + self.att_proj = nn.Linear( + config.d_model, 3 * config.d_model, bias=config.include_bias, device=config.init_device ) - self.act = nn.GELU(approximate="none") - self.c_proj = nn.Linear( - config.mlp_ratio * config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + self.att_proj._fused = (0, (self.config.d_model, 2 * self.config.d_model)) # type: ignore + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, config.mlp_ratio * config.d_model, bias=config.include_bias, device=config.init_device ) - self.c_proj._is_residual = True # type: ignore - self.dropout = nn.Dropout(config.residual_dropout) - def forward(self, x): - return self.dropout(self.c_proj(self.act(self.c_fc(x)))) + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Get query, key, value projections. + # shape (all): (batch_size, seq_len, d_model) + q, k, v = self.att_proj(self.norm(x)).split(self.config.d_model, dim=2) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(self.attention(q, k, v, attention_bias)) + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.norm(x))))) + + return x + + +class DolmaParallelBlock(DolmaBlock): + """ + This is a transformer block where the output is computed as ``MLP(LN(x)) + Attention(LN(x))`` + as in the PaLM architecture, as opposed to the typical ``MLP(LN(x + Attention(LN(x))))`` + as in :class:`DolmaSequentialBlock` (ignoring some skip connections). + + The decoupling of the MLP and Attention functions allow us to fuse the separate input projections + into a single linear layer to increase throughput. In this configuration it's also straight-forward + to fuse the output projections, but we found that didn't help. + """ -class GPTBlock(nn.Module): def __init__(self, config: ModelConfig): - super().__init__() - self.config = config - self.ln_1 = nn.LayerNorm(config.d_model, device=config.init_device) - self.attn: DolmaAttentionBase = ( - FlashAttention(config) if config.flash_attention else TorchAttention(config) + super().__init__(config) + # Fused attention and feed-forward projection. + # NOTE: we could also fuse the attention and feed-forward output projections + # but we found that didn't help, possibly because of the overhead of joining the `att` + # and `ff` activations together. + # See https://github.com/allenai/LLM/pull/79 for details. + self.fused_dims = (config.d_model, config.d_model, config.d_model, config.mlp_ratio * config.d_model) + self.fused_attn_ff_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device ) - self.ln_2 = nn.LayerNorm(config.d_model, device=config.init_device) - self.mlp = GPTMLP(config) + self.fused_attn_ff_proj._fused = (0, self.fused_dims) # type: ignore def forward( self, x: torch.Tensor, attention_bias: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: - x = x + self.attn(self.ln_1(x), attention_bias=attention_bias) - x = x + self.mlp(self.ln_2(x)) - return x + # Get query, key, value, and feed-forward projections. + # shape of q, k, v: (batch_size, seq_len, d_model) + # shape of ff: (batch_size, seq_len, mlp_ratio x d_model) + q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1) + + # Get attention scores. + # shape: (B, T, C) + att = self.attention(q, k, v, attention_bias) + # Apply output projections (and activation function) and sum the results. + # We keep these projections separate because we found that we got better throughput this + # way compared to fusing them. + return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att) -class DolmaGPTOutput(NamedTuple): + +class DolmaOutput(NamedTuple): logits: torch.FloatTensor """ A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities @@ -210,19 +413,42 @@ class DolmaGPTOutput(NamedTuple): """ -class DolmaGPT(nn.Module): +class Dolma(nn.Module): def __init__(self, config: ModelConfig, init_params: bool = True): super().__init__() self.config = config + + # Validate config. + if self.config.alibi and self.config.flash_attention: + raise DolmaConfigurationError("ALiBi is currently not supported with FlashAttention") + + if self.config.alibi and self.config.rope: + raise DolmaConfigurationError("ALiBi and RoPE are mutually exclusive") + + if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size: + if self.config.embedding_size < self.config.vocab_size: + raise DolmaConfigurationError("embedding size should be at least as big as vocab size") + elif self.config.embedding_size % 128 != 0: + import warnings + + warnings.warn( + "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning + ) + + torch.backends.cuda.enable_flash_sdp(self.config.flash_attention) + torch.backends.cuda.enable_mem_efficient_sdp(self.config.memory_efficient_attention) + self.transformer = nn.ModuleDict( dict( - wte=nn.Embedding(config.vocab_size, config.d_model, device=config.init_device), + wte=nn.Embedding( + config.embedding_size or config.vocab_size, config.d_model, device=config.init_device + ), emb_drop=nn.Dropout(config.embedding_dropout), - blocks=nn.ModuleList([GPTBlock(config) for _ in range(config.n_layers)]), - ln_f=nn.LayerNorm(config.d_model, device=config.init_device), + blocks=nn.ModuleList([DolmaBlock.build(config) for _ in range(config.n_layers)]), + ln_f=LayerNorm.build(config), ) ) - if not self.config.alibi: + if not (self.config.alibi or self.config.rope): self.transformer.update( {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)} ) @@ -230,6 +456,25 @@ def __init__(self, config: ModelConfig, init_params: bool = True): self.apply(self.param_init_fn) self.__num_fwd_flops = None + # Initialize attention bias buffers up front since calling `register_buffer` + # while compiling will cause a break in the graph. + if self.config.alibi: + self.causal_attention_bias + self.alibi_attention_bias + + @property + def buffer_dtype(self) -> torch.dtype: + """ + For some reason when we use :func:`torch.compile()` and AMP, we have to create the + attention bias buffers with the right data type. + """ + if self.config.precision == "amp_bf16": + return torch.bfloat16 + elif self.config.precision == "amp_fp16": + return torch.float16 + else: + return torch.float + @property def causal_attention_bias(self) -> torch.FloatTensor: if not hasattr(self, "_causal_attention_bias"): @@ -245,10 +490,12 @@ def causal_attention_bias(self) -> torch.FloatTensor: att_bias.masked_fill_(att_bias == 1, float("-inf")) self.register_buffer( "_causal_attention_bias", - att_bias.view(1, 1, self.config.max_sequence_length, self.config.max_sequence_length), + att_bias.to(dtype=self.buffer_dtype).view( + 1, 1, self.config.max_sequence_length, self.config.max_sequence_length + ), persistent=False, ) - return cast(torch.FloatTensor, self._causal_attention_bias) + return self._causal_attention_bias # type: ignore[return-type] @property def alibi_attention_bias(self) -> torch.FloatTensor: @@ -270,15 +517,15 @@ def alibi_attention_bias(self) -> torch.FloatTensor: # shape: (1, n_heads, seq_len, seq_len) alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, self.config.n_heads, 1, 1))) - self.register_buffer("_alibi_attention_bias", alibi_bias, persistent=False) - return cast(torch.FloatTensor, self._alibi_attention_bias) + self.register_buffer("_alibi_attention_bias", alibi_bias.to(dtype=self.buffer_dtype), persistent=False) + return self._alibi_attention_bias # type: ignore[return-type] def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, attention_bias: Optional[torch.Tensor] = None, - ) -> DolmaGPTOutput: + ) -> DolmaOutput: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates @@ -311,7 +558,7 @@ def forward( # shape: (batch_size, seq_len, d_model) x = self.transformer.wte(input_ids) # type: ignore - if not self.config.alibi: + if not (self.config.alibi or self.config.rope): # Get positional embeddings. # shape: (1, seq_len) pos = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0) @@ -326,27 +573,28 @@ def forward( # Transform the attention mask into what the blocks expect. if attention_mask is not None: # shape: (batch_size, 1, 1, seq_len) - attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :] + attention_mask = attention_mask.to(dtype=x.dtype).view(batch_size, -1)[:, None, None, :] attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min attention_mask.masked_fill_(attention_mask == 1.0, float("-inf")) - # Default to causal attention bias. - attention_bias = cast( - torch.Tensor, attention_bias if attention_bias is not None else self.causal_attention_bias - ) - if attention_bias.dtype in (torch.int8, torch.bool): - attention_bias = attention_bias.to(dtype=torch.float) - attention_bias.masked_fill_(attention_bias == 0.0, float("-inf")) + # Merge attention mask with attention bias. + if attention_bias is not None or attention_mask is not None or self.config.alibi: + if attention_bias is None: + # Default to causal attention bias. + attention_bias = self.causal_attention_bias + elif attention_bias.dtype in (torch.int8, torch.bool): + attention_bias = attention_bias.to(dtype=x.dtype) + attention_bias.masked_fill_(attention_bias == 0.0, float("-inf")) - attention_bias = attention_bias[:, :, :seq_len, :seq_len] + attention_bias = attention_bias[:, :, :seq_len, :seq_len] - # Add in the masking bias. - if attention_mask is not None: - attention_bias = attention_bias + attention_mask + # Add in the masking bias. + if attention_mask is not None: + attention_bias = attention_bias + attention_mask - if self.config.alibi: - # Add in ALiBi attention bias. - attention_bias = attention_bias + self.alibi_attention_bias[:, :, :seq_len, :seq_len] + if self.config.alibi: + # Add in ALiBi attention bias. + attention_bias = attention_bias + self.alibi_attention_bias[:, :, :seq_len, :seq_len].to(x.dtype) # Apply blocks one-by-one. for block in self.transformer.blocks: # type: ignore @@ -361,18 +609,18 @@ def forward( # shape: (batch_size, seq_len, vocab_size) logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore - return DolmaGPTOutput(logits=cast(torch.FloatTensor, logits)) + return DolmaOutput(logits=logits) # type: ignore[arg-type] def fsdp_wrap_fn(self, module): - return isinstance(module, GPTBlock) + return isinstance(module, DolmaBlock) def activation_checkpointing_fn(self, module): - return isinstance(module, GPTBlock) + return isinstance(module, DolmaBlock) def param_init_fn(self, module): from functools import partial - init_fn = partial(torch.nn.init.normal_, mean=0.0, std=self.config.init_std) + init_fn = partial(nn.init.normal_, mean=0.0, std=self.config.init_std) def fused_init_fn(module): # Parameter initialization is often based on the parameters shape. @@ -400,23 +648,23 @@ def fused_init_fn(module): init_fn(module.weight) if module.bias is not None: - torch.nn.init.zeros_(module.bias) + nn.init.zeros_(module.bias) if getattr(module, "_is_residual", False): with torch.no_grad(): module.weight.div_(math.sqrt(2 * self.config.n_layers)) if module.bias is not None: - torch.nn.init.zeros_(module.bias) + nn.init.zeros_(module.bias) # Embedding if isinstance(module, nn.Embedding): init_fn(module.weight) # LayerNorm - if isinstance(module, nn.LayerNorm): - torch.nn.init.zeros_(module.bias) + if isinstance(module, (nn.LayerNorm, LayerNorm, RMSLayerNorm)): torch.nn.init.ones_(module.weight) + torch.nn.init.zeros_(module.bias) def num_params(self, include_embedding: bool = True) -> int: """ @@ -434,7 +682,7 @@ def num_params(self, include_embedding: bool = True) -> int: def num_fwd_flops(self): if self.__num_fwd_flops: return self.__num_fwd_flops - n_params = sum(p.numel() for p in self.parameters()) + n_params = self.num_params() # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param # this gets us FLOPs / token diff --git a/dolma/optim.py b/dolma/optim.py index e624ae8de..530604c56 100644 --- a/dolma/optim.py +++ b/dolma/optim.py @@ -1,39 +1,13 @@ -import math import warnings from typing import Callable, Optional, Tuple, cast import torch -import torch.nn as nn -import torch.nn.functional as F from torch.optim.optimizer import Optimizer -from .config import OptimizerType - -__all__ = ["DecoupledLionW", "build_optimizer"] +__all__ = ["DecoupledLionW"] class DecoupledLionW(Optimizer): - metric_functions = { - "l2_norm/moment": lambda param, optim_state, step_tensor: torch.linalg.vector_norm( # pyright: ignore - optim_state["exp_avg"] - ), - "l2_norm/param": lambda param, optim_state, step_tensor: torch.linalg.vector_norm( # pyright: ignore - param.data - ), - "l2_norm/update": lambda param, optim_state, step_tensor: torch.linalg.vector_norm( # pyright: ignore - step_tensor - ), - "l2_norm/grad": lambda param, optim_state, step_tensor: torch.linalg.vector_norm( # pyright: ignore - param.grad - ), - "cosine/update_grad": lambda param, optim_state, step_tensor: F.cosine_similarity( # pyright: ignore - param.grad.flatten(), step_tensor.flatten(), dim=0 - ), - "cosine/moment_grad": lambda param, optim_state, step_tensor: F.cosine_similarity( # pyright: ignore - param.grad.flatten(), optim_state["exp_avg"].flatten(), dim=0 - ), - } - def __init__( self, params, @@ -107,151 +81,3 @@ def step(self, closure: Optional[Callable] = None): self.lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2) return loss - - def dist_reduce_metrics(self, optimizer_metrics): - from composer.utils import dist - - for metric in optimizer_metrics: - if metric.startswith("l2_norm"): - reduced = optimizer_metrics[metric] - if dist.get_world_size() > 1: - dist.all_reduce(reduced, reduce_operation="SUM") - - optimizer_metrics[metric] = math.sqrt(reduced) - elif metric.startswith("cosine"): - reduced = optimizer_metrics[metric] - if dist.get_world_size() > 1: - dist.all_reduce(reduced, reduce_operation="SUM") - - _, vectors, layer = tuple(metric.split("/")) - - A, B = tuple(vectors.split("_")) - - A_reduced_norm = optimizer_metrics[f"l2_norm/{A}/{layer}"] - B_reduced_norm = optimizer_metrics[f"l2_norm/{B}/{layer}"] - optimizer_metrics[metric] = reduced / (A_reduced_norm * B_reduced_norm) - else: - reduced = optimizer_metrics[metric] - if dist.get_world_size() > 1: - dist.all_reduce(reduced, reduce_operation="SUM") - optimizer_metrics[metric] = reduced / dist.get_world_size() - - return optimizer_metrics - - def pre_reduce_metrics(self, optimizer_metrics): - """Preprocess metrics to reduce across ranks correctly.""" - # Sort L2 norms first so they are squared before other metrics, which depend on squared values - metrics = optimizer_metrics.keys() - metrics = sorted(metrics, key=lambda metric: 0 if "l2_norm" in metric else 1) - for metric in metrics: - if metric.startswith("l2_norm"): - # L2 norms need to be squared, before they are reduced via summation - optimizer_metrics[metric] = optimizer_metrics[metric] ** 2 - elif metric.startswith("cosine"): - _, vectors, layer = tuple(metric.split("/")) - - A, B = tuple(vectors.split("_")) - - # L2 norm would've been squared in previous branch - A_rank_subset_norm = math.sqrt(optimizer_metrics[f"l2_norm/{A}/{layer}"]) - B_rank_subset_norm = math.sqrt(optimizer_metrics[f"l2_norm/{B}/{layer}"]) - - optimizer_metrics[metric] *= A_rank_subset_norm * B_rank_subset_norm - - return optimizer_metrics - - def report_per_parameter_metrics(self, param: torch.Tensor, name: str, optimizer_metrics: dict): - lr = self.param_groups[0]["lr"] - weight_decay = self.param_groups[0]["weight_decay"] - initial_lr = self.param_groups[0]["initial_lr"] - - beta1, _ = self.param_groups[0]["betas"] - if param in self.state: - param_optim_state = self.state[param] - step_tensor = param_optim_state["exp_avg"].clone().lerp_(param.grad, 1 - beta1).sign_().mul_(lr) - decay_factor = (lr / initial_lr) if initial_lr else 1.0 - step_tensor.add_(param, alpha=-weight_decay * decay_factor) - for metric in self.metric_functions: - optimizer_metrics[f"{metric}/{name}"] = self.metric_functions[metric]( - param, param_optim_state, step_tensor - ) - - return optimizer_metrics - - -def build_optimizer( - model, - name: OptimizerType = OptimizerType.decoupled_lionw, - learning_rate: Optional[float] = None, - weight_decay: float = 0.0, - betas: Tuple[float, float] = (0.9, 0.95), - eps: float = 1e-8, -) -> torch.optim.Optimizer: - """ - Get a suitable AdamW optimizer for training/fine-tuning. - - :param learning_rate: The learning rate. If not specified, a default learning - rate will calculated according to the equation from the Scaling Laws paper - `0.003239 - 0.0001395 * math.log(N)`, - where `N` is the number of trainable parameters excluding embeddings. - :param weight_decay: The weight decay coefficient. This does not apply to - biases and layernorm/embedding weights, which will have a weight decay - coefficient of 0. - :param kwargs: Other keyword arguments passed to torch's `AdamW` optimizer. - """ - # Separate out all parameters to those that will and won't experience regularizing weight decay. - decay = set() - no_decay = set() - all_params = {} - num_trainable_non_embedding_weights = 0 - for mn, m in model.named_modules(): - for pn, p in m.named_parameters(): - # NOTE: because named_modules and named_parameters are recursive - # we will see the same tensors p many many times, but doing it this way - # allows us to know which parent module any tensor p belongs to... - if not p.requires_grad: - continue - - fpn = f"{mn}.{pn}" if mn else pn - all_params[fpn] = p - - if pn.endswith("bias"): - # all biases will not be decayed - no_decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, nn.Linear): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, (nn.LayerNorm, nn.Embedding)): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - if fpn not in {"transformer.wte.weight", "transformer.wpe.weight"}: - num_trainable_non_embedding_weights += p.numel() - - # Validate that we've considered every parameter - inter_params = decay & no_decay - union_params = decay | no_decay - assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!" - assert ( - len(all_params.keys() - union_params) == 0 - ), f"parameters {all_params.keys() - union_params} were not separated into either decay/no_decay set!" - - # Create the pytorch optimizer groups. - optim_groups = [ - {"params": [all_params[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [all_params[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - - if learning_rate is None: - learning_rate = 0.003239 - 0.0001395 * math.log(num_trainable_non_embedding_weights) - - if name == OptimizerType.decoupled_lionw: - return DecoupledLionW(optim_groups, lr=learning_rate, betas=betas) - elif name == OptimizerType.decoupled_adamw: - from composer.optim import DecoupledAdamW - - return DecoupledAdamW(optim_groups, lr=learning_rate, betas=betas, eps=eps) - elif name == OptimizerType.adamw: - return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, eps=eps) - else: - raise NotImplementedError(f"Not sure how to build optimizer '{name}'") diff --git a/dolma/tokenizer.py b/dolma/tokenizer.py index 7512ea8b0..41b896151 100644 --- a/dolma/tokenizer.py +++ b/dolma/tokenizer.py @@ -127,8 +127,8 @@ def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> Li return all_input_ids - def decode(self, token_ids: List[int]) -> str: + def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str: """ Decode a list of token IDs to a string. """ - return self.base_tokenizer.decode(token_ids) + return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 2bc5ea687..000000000 --- a/mypy.ini +++ /dev/null @@ -1,6 +0,0 @@ -[mypy] -ignore_missing_imports = true -no_site_packages = true - -[mypy-tests.*] -strict_optional = false diff --git a/pyproject.toml b/pyproject.toml index 22468b9a0..7bd9fc2b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,6 @@ [tool.black] line-length = 115 - include = '\.pyi?$' - exclude = ''' ( __pycache__ @@ -27,3 +25,28 @@ exclude = ["pretrain_data/"] [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" + +[tool.mypy] +ignore_missing_imports = true +no_site_packages = true + +[[tool.mypy.overrides]] +module = "tests.*" +strict_optional = false + +[tool.pytest.ini_options] +testpaths = "tests/" +python_classes = [ + "Test*", + "*Test", +] +log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" +log_level = "DEBUG" +markers = [ + "gpu", +] +filterwarnings = [ + 'ignore::FutureWarning:huggingface_hub\.file_download', + 'ignore::DeprecationWarning:pkg_resources', + 'ignore::DeprecationWarning:google\.rpc', +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index bbf35b764..000000000 --- a/pytest.ini +++ /dev/null @@ -1,8 +0,0 @@ -[pytest] -testpaths = tests/ -python_classes = Test* *Test -log_format = %(asctime)s - %(levelname)s - %(name)s - %(message)s -log_level = DEBUG -markers = - gpu: marks tests that need GPUs -filterwarnings = diff --git a/requirements.txt b/requirements.txt index 981f36ac6..90816083c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,9 @@ # NOTE: when upgrading requirements here you may have to rebuild and push some # Docker images. See each Dockerfile for details on how to do that. numpy -torch +torch>=2.0 einops -# bug with 0.13.0, see https://github.com/mosaicml/composer/issues/2030 -mosaicml>=0.13.1 +mosaicml@git+https://github.com/allenai/composer.git@Torch2 torchmetrics tokenizers click diff --git a/scripts/beaker_interactive.sh b/scripts/beaker_interactive.sh index 6523cc0b3..20aec40b9 100644 --- a/scripts/beaker_interactive.sh +++ b/scripts/beaker_interactive.sh @@ -21,7 +21,7 @@ export WANDB_API_KEY=$(beaker secret read WANDB_API_KEY) # Create and activate environment. conda create -y -n LLM python=3.10 conda activate LLM -pip install --upgrade pip +echo "conda activate LLM" >> ~/.bashrc # Install GitHub CLI. conda install -y gh -c conda-forge @@ -30,15 +30,14 @@ conda install -y gh -c conda-forge gh auth setup-git # Install PyTorch. -conda install -y pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia +conda install -y pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia +# Install flash attn (and triton dependency) from our pre-built wheel. # We need cuda dev for the old version of triton. # NOTE: once we're able to upgrade triton to >=2.0, we can remove this. -conda install -y -c nvidia cuda-libraries-dev=11.7 cuda-nvcc=11.7 -export CUDA_HOME="$CONDA_PREFIX" - -# Install flash attn (and triton dependency) from our pre-built wheel. -pip install triton==2.0.0.dev20221202 https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu117torch1.13.1-cp310-cp310-linux_x86_64.whl +# conda install -y -c nvidia cuda-libraries-dev=11.8 cuda-nvcc=11.8 +# export CUDA_HOME="$CONDA_PREFIX" +# pip install triton==2.0.0.dev20221202 https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu118torch2.0.0-cp310-cp310-linux_x86_64.whl # Check for GPUs. python -c 'import torch; print(f"GPUs available: {torch.cuda.device_count()}")' diff --git a/scripts/train.py b/scripts/train.py index c54473b11..3b276db4b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -21,14 +21,12 @@ import logging import os import sys -from typing import List +from typing import List, cast import torch -from dolma import TrainConfig -from dolma.data import build_dataloader +from dolma import Dolma, TrainConfig from dolma.exceptions import DolmaCliError -from dolma.optim import build_optimizer from dolma.util import clean_opt, log_extra_field, prepare_cli_environment log = logging.getLogger(__name__) @@ -36,7 +34,7 @@ def main(cfg: TrainConfig) -> None: from composer import Trainer - from composer.callbacks import SpeedMonitor + from composer.callbacks import LRMonitor, OptimizerMonitor, SpeedMonitor from composer.core import Callback from composer.loggers import WandBLogger from composer.loggers.logger_destination import LoggerDestination @@ -44,13 +42,17 @@ def main(cfg: TrainConfig) -> None: from composer.utils.dist import get_node_rank from dolma.composer import ( - ComposerDolmaGPT, + ComposerDolmaLM, DolmaConsoleLogger, build_algorithm, + build_dataloader, + build_optimizer, build_scheduler, update_batch_size_info, ) + cfg.model.precision = cfg.precision + if get_node_rank() == 0: log.info("Configuration:") log.info(cfg) @@ -75,16 +77,24 @@ def main(cfg: TrainConfig) -> None: f"for global batch size of {cfg.global_train_batch_size}" ) - # Model. - model = ComposerDolmaGPT(cfg.model) + # Initialize the model. + dolma_model = Dolma(cfg.model) if get_node_rank() == 0: - log.info(f"Total number of parameters: {model.model.num_params():,d}") + log.info(f"Total number of parameters: {dolma_model.num_params():,d}") log.info( - f"Number of non-embedding parameters: {model.model.num_params(include_embedding=False):,d}", + f"Number of non-embedding parameters: {dolma_model.num_params(include_embedding=False):,d}", ) + # Compile it if necessary. + if cfg.compile is not None: + compile_kwargs = cfg.compile.asdict() + if compile_kwargs.get("fullgraph") is None: + compile_kwargs["fullgraph"] = cfg.fsdp_config is None + # As far as duck typing is concerned, this is still a Dolma object. + dolma_model = cast(Dolma, torch.compile(dolma_model, **compile_kwargs)) + # Optimizer. - optimizer = build_optimizer(model.model, **cfg.optimizer.asdict()) + optimizer = build_optimizer(dolma_model, **cfg.optimizer.asdict()) # Scheduler. scheduler = build_scheduler(cfg.scheduler) @@ -93,21 +103,33 @@ def main(cfg: TrainConfig) -> None: train_loader = build_dataloader(cfg, cfg.device_train_batch_size) # Algorithms. - algorithms = [build_algorithm(name, algorithm_cfg) for name, algorithm_cfg in (cfg.algorithms or {}).items()] + algorithms = [ + build_algorithm(name, algorithm_cfg) + for name, algorithm_cfg in (cfg.algorithms or {}).items() + if algorithm_cfg is not None + ] # Callbacks. - callbacks: List[Callback] = [SpeedMonitor(**cfg.speed_monitor.asdict())] + callbacks: List[Callback] = [ + SpeedMonitor(**cfg.speed_monitor.asdict()), + LRMonitor(), + OptimizerMonitor(log_optimizer_metrics=False), + ] # Loggers. loggers: List[LoggerDestination] = [DolmaConsoleLogger(log_interval=cfg.console_log_interval)] if cfg.wandb is not None: loggers.append(WandBLogger(init_kwargs={"config": cfg.asdict(exclude=["wandb"])}, **cfg.wandb.asdict())) + # Wrap model into composer model. + composer_model = ComposerDolmaLM(dolma_model) + del dolma_model + # Trainer. trainer = Trainer( run_name=cfg.run_name, seed=cfg.seed, - model=model, + model=composer_model, train_dataloader=train_loader, optimizers=optimizer, schedulers=scheduler, diff --git a/test_fixtures/train_tiny.yaml b/test_fixtures/train_tiny.yaml index 50f621194..43e245aca 100644 --- a/test_fixtures/train_tiny.yaml +++ b/test_fixtures/train_tiny.yaml @@ -25,6 +25,8 @@ data: paths: - "/tmp/c4-sample.npy" persistent_workers: false + num_workers: 0 + prefetch_factor: null tokenizer: identifier: "gpt2" save_overwrite: true diff --git a/tests/model_test.py b/tests/model_test.py index 68ad88440..f834fcbd5 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -2,21 +2,28 @@ import torch from torch.nn import CrossEntropyLoss -from dolma import DolmaGPT, ModelConfig, Tokenizer, TrainConfig +from dolma import BlockType, Dolma, ModelConfig, Tokenizer, TrainConfig +from dolma.composer import build_optimizer from dolma.data import DataCollator -from dolma.optim import build_optimizer @pytest.mark.parametrize( - "alibi, flash_attn, cuda, dtype", + "alibi, rope, flash_attn, block_type, cuda, dtype", [ - pytest.param(True, False, False, torch.bfloat16, id="alibi-emb-cpu-bf16"), - pytest.param(False, False, False, torch.bfloat16, id="posit-emb-cpu-bf16"), - pytest.param(True, False, False, torch.float32, id="alibi-emb-cpu-f32"), - pytest.param(False, False, False, torch.float32, id="posit-emb-cpu-f32"), + pytest.param(True, False, False, BlockType.sequential, False, torch.bfloat16, id="alibi-emb-cpu-bf16"), + pytest.param( + True, False, False, BlockType.parallel, False, torch.bfloat16, id="alibi-emb-parallel-block-cpu-bf16" + ), + pytest.param(False, False, False, BlockType.sequential, False, torch.bfloat16, id="posit-emb-cpu-bf16"), + pytest.param(True, False, False, BlockType.sequential, False, torch.float32, id="alibi-emb-cpu-f32"), + pytest.param(False, False, False, BlockType.sequential, False, torch.float32, id="posit-emb-cpu-f32"), + pytest.param(False, True, False, BlockType.sequential, False, torch.bfloat16, id="rope-emb-cpu-bf16"), + pytest.param(False, True, False, BlockType.sequential, False, torch.float32, id="rope-emb-cpu-f32"), pytest.param( True, False, + False, + BlockType.sequential, True, torch.bfloat16, id="alibi-emb-cuda-bf16", @@ -26,22 +33,26 @@ ), ), pytest.param( + True, False, False, + BlockType.parallel, True, torch.bfloat16, - id="posit-emb-cuda-bf16", + id="alibi-emb-parallel-block-cuda-bf16", marks=( pytest.mark.gpu, pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), ), ), pytest.param( + False, True, - True, + False, + BlockType.sequential, True, torch.bfloat16, - id="alibi-emb-flash-cuda-bf16", + id="rope-emb-cuda-bf16", marks=( pytest.mark.gpu, pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), @@ -49,29 +60,35 @@ ), pytest.param( False, - True, + False, + False, + BlockType.sequential, True, torch.bfloat16, - id="posit-emb-flash-cuda-bf16", + id="posit-emb-cuda-bf16", marks=( pytest.mark.gpu, pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), ), ), pytest.param( + False, + False, True, + BlockType.sequential, True, - True, - torch.float16, - id="alibi-emb-flash-cuda-f16", + torch.bfloat16, + id="posit-emb-flash-cuda-bf16", marks=( pytest.mark.gpu, pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), ), ), pytest.param( + False, False, True, + BlockType.sequential, True, torch.float16, id="posit-emb-flash-cuda-f16", @@ -83,14 +100,24 @@ ], ) def test_forward( - train_config: TrainConfig, tokenizer: Tokenizer, alibi: bool, flash_attn: bool, cuda: bool, dtype + train_config: TrainConfig, + tokenizer: Tokenizer, + alibi: bool, + rope: bool, + flash_attn: bool, + block_type: BlockType, + cuda: bool, + dtype, ): torch.manual_seed(0) + torch.use_deterministic_algorithms(True) train_config.model.alibi = alibi + train_config.model.rope = rope train_config.model.flash_attention = flash_attn if flash_attn: train_config.model.attention_dropout = 0.0 + train_config.model.block_type = block_type if cuda: train_config.model.init_device = "cuda" else: @@ -98,7 +125,7 @@ def test_forward( use_amp = dtype in {torch.float16, torch.bfloat16} - model = DolmaGPT(train_config.model).eval() + model = Dolma(train_config.model).eval() input1 = tokenizer.encode("My name is DOLMA!") input2 = tokenizer.encode("I'm a delightful large open language model :)") @@ -161,21 +188,6 @@ def test_forward( pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), ), ), - pytest.param( - True, - True, - True, - torch.bfloat16, - id="alibi-emb-flash-cuda-bf16", - marks=( - pytest.mark.gpu, - pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"), - pytest.mark.skipif( - torch.cuda.device_count() < 1 or "A100" not in torch.cuda.get_device_name(), - reason="Requires A100 GPU type", - ), - ), - ), pytest.param( False, True, @@ -210,7 +222,7 @@ def test_backward( else: train_config.model.init_device = "cpu" - model = DolmaGPT(train_config.model).train() + model = Dolma(train_config.model).train() with torch.autocast( device_type="cuda" if cuda else "cpu", enabled=use_amp, dtype=None if not use_amp else dtype @@ -242,4 +254,4 @@ def test_backward( def test_build_optimizer(model_config: ModelConfig): - build_optimizer(DolmaGPT(model_config)) + build_optimizer(Dolma(model_config))