Merge pull request #61 from allenai/Torch2

Torch 2.0
dirkgr authored Apr 5, 2023
2 parents caf1f32 + 5613711 commit ef3e157
Showing 27 changed files with 969 additions and 528 deletions.
12 changes: 5 additions & 7 deletions .github/workflows/main.yml
@@ -7,6 +7,7 @@ on:
pull_request:
branches:
- main
- Torch2
push:
branches:
- main
@@ -108,7 +109,7 @@ jobs:
timeout-minutes: 15
env:
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
BEAKER_IMAGE: dolma-test
BEAKER_IMAGE: dolma-torch2-test
BEAKER_WORKSPACE: ai2/llm-testing
steps:
- name: Determine current commit SHA (pull request)
@@ -133,26 +134,23 @@ jobs:
image:
beaker: ${{ env.BEAKER_IMAGE }}
context:
priority: preemptible
priority: normal
resources:
gpuCount: 1
constraints:
cluster:
- ai2/general-cirrascale
- ai2/general-cirrascale-a100-80g-ib
- ai2/allennlp-cirrascale
- ai2/aristo-cirrascale
- ai2/mosaic-cirrascale
- ai2/mosaic-cirrascale-a100
- ai2/prior-cirrascale
- ai2/s2-cirrascale
envVars:
- name: COMMIT_SHA
value: ${{ env.COMMIT_SHA }}
- name: GITHUB_TOKEN
value: ${{ secrets.GITHUB_TOKEN }}
- name: CUDA_LAUNCH_BLOCKING
value: "1"
- name: CUBLAS_WORKSPACE_CONFIG
value: ":16:8"
- name: TOKENIZERS_PARALLELISM
value: "false"
command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/"]
46 changes: 41 additions & 5 deletions LOG.md
@@ -1,6 +1,47 @@
Experiment Log
==============

2023-04-03
----------

We added the option to decouple the MLP and Attention computations as in the PaLM architecture.
That is, within each transformer block we compute `MLP(LN(x)) + Attention(LN(x))` instead of `MLP(LN(x + Attention(LN(x))))` (ignoring some skip connections).
This increases throughput because we can fuse the separate feed-forward and attention input projections into a single linear layer.
We also experimented with [fusing the output projections](https://github.com/allenai/LLM/pull/79) into a single linear layer but that didn't help, possibly due to the overhead of concatenating the feed-forward and attention activations together.
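
To make the parallel formulation concrete, here is a minimal PyTorch sketch of the two block variants. It is an illustration only, not the project's implementation: the GELU MLP, the use of `nn.MultiheadAttention`, and the dimension names are assumptions, and the point of the parallel variant is that both branches can share a fused input projection.

```python
import torch
import torch.nn as nn


class ParallelBlock(nn.Module):
    """PaLM-style block: x + MLP(LN(x)) + Attention(LN(x)).

    Both branches read the same LN(x), so their input projections can be
    fused into a single matmul in an optimized implementation.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln(x)
        return x + self.attn(h, h, h, need_weights=False)[0] + self.mlp(h)


class SequentialBlock(nn.Module):
    """Standard block for comparison: attention first, then the MLP."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.ln2(x))


x = torch.randn(2, 128, 512)
y = ParallelBlock(d_model=512, n_heads=8)(x)
```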

2023-03-28
----------

We've investigated a number of ways to optimize training throughput in terms of tokens per second and MFU (model FLOPs utilization). This is a list of all of the optimizations that have worked so far, ranked by how much of a speedup they gave on a 1.2b param model:

1. Using FlashAttention via PyTorch's built-in `scaled_dot_product_attention` function. This resulted in a ~12% speedup over the default attention implementation while also reducing GPU memory utilization. (A sketch of items 1-4 follows this list.)

Unfortunately ALiBi can't be used with FlashAttention at the moment, so the best option if we want relative positional encodings is probably RoPE (which can be used with FlashAttention). RoPE on its own is slower than ALiBi, but RoPE combined with FlashAttention is faster than ALiBi without it. Of course ALiBi + FlashAttention would be ideal.

1. Setting embedding/vocab size to a multiple of 128. E.g. the actual vocab size is 50257, but we force the embedding size to be 50304. This resulted in an ~11% speedup.
1. Using low-precision LayerNorm when **not** using `torch.compile()`. This resulted in a speedup of ~10%, but it actually slows throughput when using a compiled model. This probably has to do with manually casting tensors to different data types, which causes more breaks in the graph.
1. Compiling the model via `torch.compile()` with the default mode. This resulted in a ~7% speedup without increasing (and in some cases decreasing) GPU memory utilization.

The other compile modes ("reduce-overhead" and "max-autotune") were not as fast and required substantially more GPU memory.

Compiling with `fullgraph=True` improves throughput even further, except when using FSDP, since FSDP forces breaks in the graph.
1. Tweaking the FSDP settings to use "PURE" mixed precision, limit all-gathers, and use non-reentrant activation checkpointing resulted in a 1-2% speedup (sketched further below).
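
The following is a rough sketch of items 1-4 under assumed shapes and model sizes; none of this is the project's training code, and the `LowPrecisionLayerNorm` here only mirrors the idea behind item 3.

```python
import math

import torch
import torch.nn.functional as F

# (1) FlashAttention via the built-in SDPA kernel (PyTorch >= 2.0). The Flash
# path is only taken on CUDA with suitable dtypes/shapes; otherwise SDPA falls
# back to the regular math implementation.
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(2, 8, 1024, 64, device=device, dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# (2) Round the embedding table up to a multiple of 128 so the matmuls hit
# efficient kernel tile sizes; the extra rows are simply never used.
vocab_size = 50257
embedding_size = 128 * math.ceil(vocab_size / 128)  # -> 50304
wte = torch.nn.Embedding(embedding_size, 2048)


# (3) Low-precision LayerNorm: helps when *not* compiling, but the explicit
# casts cause graph breaks and slow down a compiled model.
class LowPrecisionLayerNorm(torch.nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight.to(torch.bfloat16) if self.weight is not None else None
        b = self.bias.to(torch.bfloat16) if self.bias is not None else None
        return F.layer_norm(x.to(torch.bfloat16), self.normalized_shape, w, b, self.eps)


# (4) Compile with the default mode; "reduce-overhead" and "max-autotune" were
# slower and needed substantially more GPU memory in our benchmarks.
model = torch.nn.TransformerEncoderLayer(d_model=2048, nhead=8, batch_first=True)
compiled_model = torch.compile(model)  # mode="default", fullgraph=False
```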

Using the best compatible combination of the above settings (so everything except #3) gets us close to 60% MFU with the 1.2b model. That's really good!
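
For item 5, here is roughly what those FSDP settings look like at the raw PyTorch level. It's a sketch with a toy model, no auto-wrap policy, and no process-group setup; in the project these options are set through the trainer's `fsdp_config`.

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# A toy stand-in model. Real usage also needs torch.distributed to be
# initialized (e.g. via torchrun), which is omitted here.
model = torch.nn.TransformerEncoderLayer(d_model=2048, nhead=8, batch_first=True)

# "PURE" mixed precision: parameters, gradient reduction, and buffers all in bf16.
pure_bf16 = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=pure_bf16,
    limit_all_gathers=True,   # rate-limit prefetched all-gathers
    use_orig_params=True,     # needed to play nicely with torch.compile()
)

# Non-reentrant activation checkpointing.
apply_activation_checkpointing(
    fsdp_model,
    checkpoint_wrapper_fn=lambda m: checkpoint_wrapper(
        m, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
)
```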

For more details, see:
- [Benchmarking the performance of PyTorch's new `compile()` and built-in FlashAttention.](https://wandb.ai/ai2-llm/petew-torch2-benchmarks/reports/PyTorch-2-0-benchmarks--VmlldzozODQyMDY5?accessToken=2fh801xe265n5xx7juphb1xnx8itvls8g7nrqsjdd4ja0xlks7kaozue94z2mez3)
- [Benchmarking the cost of using RoPE](https://wandb.ai/ai2-llm/rope-benchmarks/reports/Benchmarking-RoPE--VmlldzozODQ1MjMz)
- [Benchmarking the performance of `compile()` with FSDP](https://wandb.ai/ai2-llm/fsdp-compile-benchmarks)
- [Benchmarking low precision LayerNorm](https://api.wandb.ai/links/ai2-llm/9favfpnh)


2023-03-15
----------

The cluster is down for maintenance, so we're just queueing up some features we want to run. We also used the LUMI downtime to build a better logging feature. When running 1000s of nodes in a cluster, it's difficult to get logs that make sense. We're sending our logs to third-party logging provider [logz.io](https://logz.io). It's basic, but it gets the job done.


2023-03-14
----------

@@ -16,8 +57,3 @@ Findings:
I'm not sure what that buys us, and it's one extra component in the mix, so I didn't do it that way.
* Automatic restarts work. One run got killed and automatically restarted.
It is great that restarts work, but somewhat worrisome that we're already sampling this behavior after less than 45 minutes of runtime on only one node.

2023-03-15
----------

The cluster is down for maintenance, so we're just queueing up some features we want to run. We also used the LUMI downtime to build a better logging feature. When running 1000s of nodes in a cluster, it's difficult to get logs that make sense. We're sending our logs to third-party logging provider [logz.io](https://logz.io). It's basic, but it gets the job done.
6 changes: 3 additions & 3 deletions Makefile
@@ -1,7 +1,7 @@
# If you update this, also update BEAKER_IMAGE in .github/workflows/main.yml
IMAGE_NAME_BASE = dolma
IMAGE_NAME_BASE = dolma-torch2
# If you update this, also update BEAKER_WORKSPACE in .github/workflows/main.yml
BEAKER_WORKSPACE = "ai2/llm-testing"
BEAKER_WORKSPACE = ai2/llm-testing

BEAKER_USER = $(shell beaker account whoami --format=json | jq -r '.[0].name')
GANTRY_IMAGE = $(shell beaker workspace images $(BEAKER_WORKSPACE) --format=json | jq -r -c '.[] | select( .name == "$(IMAGE_NAME_BASE)-gantry" ) | .fullName')
@@ -58,7 +58,7 @@ show-beaker-workspace :
gantry-test :
gantry run \
--workspace "$(BEAKER_WORKSPACE)" \
--priority "preemptible" \
--priority "normal" \
--beaker-image "$(GANTRY_IMAGE)" \
--gpus 1 \
--description "Test run" \
4 changes: 2 additions & 2 deletions README.md
@@ -24,7 +24,7 @@ gantry run \
--nfs \
--priority preemptible \
--gpus 8 \
--beaker-image dolma-gantry \
--beaker-image dolma-torch2-gantry \
--cluster 'ai2/*-cirrascale' \
--allow-dirty \
-- composer scripts/train.py configs/1.2b-c4.yaml
@@ -36,7 +36,7 @@ Train the 70B model on c4 with gantry across multiple nodes:
gantry run \
--workspace ai2/llm-testing \
--priority "high" \
--beaker-image dolma-gantry \
--beaker-image dolma-torch2-gantry \
--cluster ai2/general-cirrascale-a100-80g-ib \
--gpus 8 \
--nfs \
18 changes: 14 additions & 4 deletions configs/1.2b-c4.yaml
@@ -11,17 +11,24 @@ model:
alibi_bias_max: 8.0
attention_dropout: 0.0
attention_layer_norm: true
layer_norm_type: default # if not compiling, use 'low_precision'
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
max_sequence_length: 1024
vocab_size: 50257
embedding_size: 50304
eos_token_id: 50256
pad_token_id: 50256
init_device: meta
init_device: null
init_std: 0.02

compile:
mode: default
fullgraph: null

optimizer:
name: decoupled_adamw
name: decoupled_lionw
learning_rate: 2.0e-4
weight_decay: 1.2e-4
betas:
@@ -71,13 +78,16 @@ device_eval_batch_size: null

n_gpus: null

precision: null
precision: amp_bf16

fsdp_config:
sharding_strategy: FULL_SHARD
mixed_precision: DEFAULT
mixed_precision: PURE
activation_checkpointing: false
activation_cpu_offload: false
activation_checkpointing_reentrant: false
limit_all_gathers: true
use_orig_params: true # needed to work with compile
verbose: false

speed_monitor:
15 changes: 12 additions & 3 deletions configs/300m-c4.yaml
@@ -12,18 +12,24 @@ model:
flash_attention: false
attention_dropout: 0.0
attention_layer_norm: false
layer_norm_type: default # if not compiling, use 'low_precision'
residual_dropout: 0.0
embedding_dropout: 0.0
max_sequence_length: 1024
include_bias: true
vocab_size: 50257
embedding_size: 50304
eos_token_id: 50256
pad_token_id: 50256
init_device: null
init_std: 0.02

compile:
mode: default
fullgraph: null

optimizer:
name: decoupled_adamw
name: decoupled_lionw
learning_rate: 3.0e-4
weight_decay: 1.2e-4
betas:
@@ -72,13 +78,16 @@ device_eval_batch_size: null

n_gpus: null

precision: null
precision: amp_bf16

fsdp_config:
sharding_strategy: FULL_SHARD
mixed_precision: DEFAULT
mixed_precision: PURE
activation_checkpointing: false
activation_cpu_offload: false
activation_checkpointing_reentrant: false
limit_all_gathers: true
use_orig_params: true # needed to work with compile
verbose: false

wandb:
15 changes: 13 additions & 2 deletions configs/70b-c4.yaml
@@ -12,17 +12,25 @@ model:
flash_attention: false
attention_dropout: 0.0 # has to be 0 if using flash attn
attention_layer_norm: true
block_type: parallel
layer_norm_type: default # if not compiling, use 'low_precision'
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
max_sequence_length: 2048
vocab_size: 50257
embedding_size: 50304
eos_token_id: 50256
pad_token_id: 50256
init_device: meta
init_std: 0.02

compile:
mode: default
fullgraph: null

optimizer:
name: decoupled_adamw
name: decoupled_lionw
learning_rate: 8.0e-5
weight_decay: 1.2e-4
betas:
@@ -76,9 +84,12 @@ precision: amp_bf16

fsdp_config:
sharding_strategy: FULL_SHARD
mixed_precision: DEFAULT # could be PURE with flash attn
mixed_precision: PURE
activation_checkpointing: true
activation_cpu_offload: false
activation_checkpointing_reentrant: false
limit_all_gathers: true
use_orig_params: true # needed to work with compile
verbose: false

speed_monitor:
4 changes: 2 additions & 2 deletions configs/tiny.yaml
@@ -23,7 +23,7 @@ model:
init_std: 0.02

optimizer:
name: decoupled_adamw
name: decoupled_lionw
learning_rate: 3.0e-4
weight_decay: 1.2e-4
betas:
@@ -72,7 +72,7 @@ device_eval_batch_size: null

n_gpus: null

precision: null
precision: amp_bf16

fsdp_config: null

13 changes: 6 additions & 7 deletions docker/Dockerfile.base
@@ -1,15 +1,14 @@
# Defines a CUDA-enabled Docker image suitable for installing all dependencies
# to this project.

FROM ghcr.io/allenai/pytorch:1.13.1-cuda11.7-python3.10
FROM ghcr.io/allenai/pytorch:2.0.0-cuda11.8-python3.10

# Install flash attn (and triton dependency) from our pre-built wheel.
# We need cuda dev for the old version of triton.
# NOTE: once we're able to upgrade triton to >=2.0, we can remove this.
RUN /opt/conda/bin/conda install -c nvidia cuda-libraries-dev

# Install flash attn (and triton dependency) from our pre-built wheel.
RUN /opt/conda/bin/pip install --no-cache-dir \
triton==2.0.0.dev20221202 \
https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu117torch1.13.1-cp310-cp310-linux_x86_64.whl
# RUN /opt/conda/bin/conda install -c nvidia cuda-libraries-dev
# RUN /opt/conda/bin/pip install --no-cache-dir \
# triton==2.0.0.dev20221202 \
# https://storage.googleapis.com/ai2-python-wheels/flash_attn/flash_attn-0.2.8%2Bcu118torch2.0.0-cp310-cp310-linux_x86_64.whl

ENV CUDA_HOME=/opt/conda
2 changes: 1 addition & 1 deletion docker/Dockerfile.gantry
@@ -4,7 +4,7 @@
# To build and push the image to Beaker, run 'make gantry-image'.
# To test the image after pushing to Beaker, run 'make gantry-test'.

FROM dolma-base
FROM dolma-torch2-base

WORKDIR /stage

39 changes: 31 additions & 8 deletions docker/Dockerfile.lumi
@@ -1,16 +1,39 @@
FROM rocm/dev-ubuntu-22.04:5.4-complete
FROM ubuntu:latest

ENV DEBIAN_FRONTEND=noninteractive
ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8

# Install various softwares
RUN apt-get update
RUN apt-get upgrade -y
RUN apt-get install -y python-is-python3 git autoconf python3-dev git vim libtool openjdk-8-jdk-headless xvfb fish build-essential wget parallel s3cmd awscli rocm-libs rccl
RUN apt-get install -y \
python-is-python3 \
python3-dev \
libpython3-all-dev \
python-dev-is-python3 \
python3-pip \
build-essential \
git \
autoconf \
libtool \
llvm \
vim \
fish \
wget \
parallel \
s3cmd \
awscli \
htop \
wget \
fish

# Fix for Java trying to find assistive techs in headless java
# https://askubuntu.com/questions/695560/assistive-technology-not-found-awterror
RUN sed -i -e '/^assistive_technologies=/s/^/#/' /etc/java-8-openjdk/accessibility.properties
# Install ROCm
RUN wget https://repo.radeon.com/amdgpu-install/5.4.3/ubuntu/jammy/amdgpu-install_5.4.50403-1_all.deb && \
apt-get install -y ./amdgpu-install_5.4.50403-1_all.deb && \
amdgpu-install -y --accept-eula --usecase=rocm --no-dkms && \
rm ./amdgpu-install_5.4.50403-1_all.deb && \
apt-get install -y rccl rocm-libs

# Install MPICH
ENV MPICH_VERSION="3.1.4"
@@ -34,8 +57,3 @@ ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH

# Install torch
RUN pip install --upgrade pip
RUN pip install --no-cache-dir "torch<2.0" torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
#RUN pip install --pre --no-cache-dir "torch<2.0" torchvision torchaudio torchtext --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.3
RUN pip install --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2

# Install DeepSpeed
RUN pip install --no-cache-dir mpi4py
@@ -49,9 +71,10 @@ RUN cd /opt && \
COPY requirements.txt requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-cache-dir py-spy
RUN pip install wandb --upgrade
RUN pip install --no-cache-dir wandb --upgrade

# Cleanup
RUN apt-get autoremove
RUN rm -rf /opt/mpich-3.1.4 /opt/aws-ofi-rccl /opt/DeepSpeed
RUN apt-get clean
RUN pip cache purge
2 changes: 1 addition & 1 deletion docker/Dockerfile.test
@@ -4,7 +4,7 @@
#
# To build and push the image to Beaker, run 'make test-image'.

FROM dolma-base
FROM dolma-torch2-base

COPY scripts/test_entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh