-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Torchao float8 training #3348
Open
muellerzr
wants to merge
9
commits into
main
Choose a base branch
from
torchao-float8
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Torchao float8 training #3348
Changes from 5 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
0b5a9c1
Bookmark
muellerzr b2cce71
bookmark
muellerzr e1a1304
Add torchao base example
muellerzr be210db
Currently broken
muellerzr fbeb5a7
Clean
muellerzr 3820c40
DDP varient working
muellerzr 4dfa816
FSDP as well
muellerzr 04c9f56
Works for all but zero3
muellerzr f805876
Bookmark: currently zero3 is underperforming
muellerzr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
FROM nvcr.io/nvidia/pytorch:24.07-py3 | ||
|
||
RUN pip install transformers evaluate datasets | ||
RUN git clone https://github.com/huggingface/accelerate.git | ||
|
||
RUN cd accelerate && \ | ||
pip install -e . && \ | ||
cd benchmarks/fp8 | ||
|
||
RUN /bin/bash | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# FP8 Benchmarks | ||
|
||
Comparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accelerate | ||
|
||
## Overview | ||
|
||
This repo provides scripts which compare native `torchao` model training against `accelerate`'s own integration. Each modeling type is segmented out via a script, supporting the following: | ||
|
||
* Single GPU training (`non_distributed.py`) | ||
* Multi-GPU training via DistributedDataParallelism (`ddp.py`) | ||
* Fully Sharded Data Parallelism (`fsdp.py`) | ||
* DeepSpeed ZeRO 1-3 (`deepspeed.py`) | ||
|
||
To run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `torchao` manually. | ||
|
||
## Running: | ||
|
||
There are official Docker images located at `huggingface/accelerate:gpu-fp8-torchao-nightly` which can be used. | ||
|
||
You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed. | ||
|
||
For single GPU, run it via `python`: | ||
|
||
```bash | ||
python non_distributed.py | ||
``` | ||
|
||
For the rest, run it via `accelerate launch`: | ||
|
||
```bash | ||
accelerate launch ddp.py # or distrib_deepspeed.py, ddp.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. | ||
|
||
This particular script verifies this for DDP training. | ||
""" | ||
|
||
import evaluate | ||
import torch | ||
import transformer_engine.common.recipe as te_recipe | ||
import transformer_engine.pytorch as te | ||
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities | ||
from torch.nn.parallel import DistributedDataParallel as DDP | ||
from transformer_engine.common.recipe import DelayedScaling | ||
|
||
from accelerate import Accelerator | ||
from accelerate.state import AcceleratorState | ||
from accelerate.utils import FP8RecipeKwargs, set_seed | ||
from accelerate.utils.transformer_engine import convert_model | ||
|
||
|
||
MODEL_NAME = "bert-base-cased" | ||
METRIC = evaluate.load("glue", "mrpc") | ||
|
||
|
||
def train_baseline(): | ||
set_seed(42) | ||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) | ||
accelerator = Accelerator() | ||
device = accelerator.device | ||
model.to(device) | ||
|
||
# Convert the model to TE | ||
old_named_params = get_named_parameters(model) | ||
|
||
with torch.no_grad(): | ||
convert_model(model) | ||
|
||
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} | ||
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) | ||
|
||
new_named_params = get_named_parameters(model) | ||
|
||
# Convert the model to DDP | ||
device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index | ||
model = DDP(model, device_ids=device_ids, output_device=output_device) | ||
|
||
mapping = {p: new_named_params[n] for n, p in old_named_params.items()} | ||
for param_group in optimizer.param_groups: | ||
param_group["params"] = [mapping[p] for p in param_group["params"]] | ||
|
||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
model.train() | ||
|
||
for _ in range(2): | ||
for batch in train_dataloader: | ||
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): | ||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): | ||
batch = batch.to(device) | ||
outputs = model(**batch) | ||
loss = outputs.loss | ||
loss.backward() | ||
optimizer.step() | ||
optimizer.zero_grad() | ||
lr_scheduler.step() | ||
|
||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
|
||
assert ( | ||
trained_model_results["accuracy"] > base_model_results["accuracy"] | ||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' | ||
assert ( | ||
trained_model_results["f1"] > base_model_results["f1"] | ||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' | ||
|
||
return base_model_results, trained_model_results | ||
|
||
|
||
def train_integration(): | ||
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} | ||
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] | ||
AcceleratorState()._reset_state(True) | ||
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) | ||
set_seed(42) | ||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( | ||
MODEL_NAME, accelerator=accelerator | ||
) | ||
|
||
model, optimizer = accelerator.prepare(model, optimizer) | ||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
model.train() | ||
|
||
for _ in range(2): | ||
for batch in train_dataloader: | ||
outputs = model(**batch) | ||
loss = outputs.loss | ||
accelerator.backward(loss) | ||
optimizer.step() | ||
optimizer.zero_grad() | ||
lr_scheduler.step() | ||
|
||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
|
||
assert ( | ||
trained_model_results["accuracy"] > base_model_results["accuracy"] | ||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' | ||
assert ( | ||
trained_model_results["f1"] > base_model_results["f1"] | ||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' | ||
|
||
return base_model_results, trained_model_results | ||
|
||
|
||
if __name__ == "__main__": | ||
baseline_not_trained, baseline_trained = train_baseline() | ||
accelerator_not_trained, accelerator_trained = train_integration() | ||
|
||
assert ( | ||
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] | ||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' | ||
assert ( | ||
baseline_not_trained["f1"] == accelerator_not_trained["f1"] | ||
), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' | ||
assert ( | ||
baseline_trained["accuracy"] == accelerator_trained["accuracy"] | ||
), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' | ||
assert ( | ||
baseline_trained["f1"] == accelerator_trained["f1"] | ||
), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' | ||
|
||
torch.distributed.destroy_process_group() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. | ||
|
||
This particular script verifies this for DDP training. | ||
""" | ||
|
||
from unittest.mock import patch | ||
|
||
import deepspeed | ||
import evaluate | ||
import torch | ||
import transformer_engine.common.recipe as te_recipe | ||
import transformer_engine.pytorch as te | ||
from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities | ||
from transformer_engine.common.recipe import DelayedScaling | ||
|
||
from accelerate import Accelerator, DeepSpeedPlugin | ||
from accelerate.state import AcceleratorState | ||
from accelerate.utils import FP8RecipeKwargs, set_seed | ||
from accelerate.utils.transformer_engine import convert_model | ||
|
||
|
||
MODEL_NAME = "bert-base-cased" | ||
METRIC = evaluate.load("glue", "mrpc") | ||
|
||
|
||
def train_baseline(zero_stage: int = 1): | ||
# This forces transformers to think Zero-3 Init should be used | ||
with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: | ||
mock.return_value = zero_stage == 3 | ||
set_seed(42) | ||
|
||
accelerator = Accelerator() | ||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( | ||
MODEL_NAME, accelerator=accelerator | ||
) | ||
|
||
# Convert the model to TE | ||
old_named_params = get_named_parameters(model) | ||
|
||
with torch.no_grad(): | ||
convert_model(model) | ||
new_named_params = get_named_parameters(model) | ||
|
||
mapping = {p: new_named_params[n] for n, p in old_named_params.items()} | ||
for param_group in optimizer.param_groups: | ||
param_group["params"] = [mapping[p] for p in param_group["params"]] | ||
|
||
FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} | ||
fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) | ||
|
||
import numpy as np | ||
|
||
config = { | ||
"train_batch_size": 32, | ||
"train_micro_batch_size_per_gpu": 16, | ||
"gradient_accumulation_steps": 1, | ||
"zero_optimization": { | ||
"stage": zero_stage, | ||
"offload_optimizer": {"device": "none", "nvme_path": None}, | ||
"offload_param": {"device": "none", "nvme_path": None}, | ||
"stage3_gather_16bit_weights_on_model_save": False, | ||
}, | ||
"gradient_clipping": 1.0, | ||
"steps_per_print": np.inf, | ||
"bf16": {"enabled": True}, | ||
"fp16": {"enabled": False}, | ||
"zero_allow_untested_optimizer": True, | ||
} | ||
|
||
( | ||
model, | ||
optimizer, | ||
_, | ||
_, | ||
) = deepspeed.initialize( | ||
model=model, | ||
optimizer=optimizer, | ||
config_params=config, | ||
) | ||
|
||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
model.train() | ||
|
||
model_outputs = [] | ||
data = [] | ||
|
||
for _ in range(2): | ||
for batch in train_dataloader: | ||
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): | ||
outputs = model(**batch) | ||
data.append(batch.to("cpu")) | ||
model_outputs.append(outputs.logits.to("cpu")) | ||
loss = outputs.loss | ||
model.backward(loss) | ||
model.step() | ||
for _ in range(accelerator.num_processes): | ||
lr_scheduler.step() | ||
|
||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
model.destroy() | ||
assert ( | ||
trained_model_results["accuracy"] > base_model_results["accuracy"] | ||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' | ||
assert ( | ||
trained_model_results["f1"] > base_model_results["f1"] | ||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' | ||
|
||
return base_model_results, trained_model_results, model_outputs, data | ||
|
||
|
||
def train_integration(zero_stage: int = 1): | ||
set_seed(42) | ||
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} | ||
kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] | ||
AcceleratorState()._reset_state(True) | ||
deepspeed_plugin = DeepSpeedPlugin( | ||
zero_stage=zero_stage, | ||
zero3_init_flag=zero_stage == 3, | ||
) | ||
accelerator = Accelerator( | ||
mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin | ||
) | ||
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 | ||
|
||
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( | ||
MODEL_NAME, accelerator=accelerator | ||
) | ||
|
||
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) | ||
base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
model.train() | ||
model_outputs = [] | ||
data = [] | ||
for _ in range(2): | ||
for batch in train_dataloader: | ||
outputs = model(**batch) | ||
data.append(batch.to("cpu")) | ||
model_outputs.append(outputs.logits.to("cpu")) | ||
loss = outputs.loss | ||
accelerator.backward(loss) | ||
optimizer.step() | ||
lr_scheduler.step() | ||
optimizer.zero_grad() | ||
|
||
trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) | ||
model.destroy() | ||
assert ( | ||
trained_model_results["accuracy"] > base_model_results["accuracy"] | ||
), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' | ||
assert ( | ||
trained_model_results["f1"] > base_model_results["f1"] | ||
), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' | ||
|
||
return base_model_results, trained_model_results, model_outputs, data | ||
|
||
|
||
if __name__ == "__main__": | ||
# for zero_stage in [1, 2, 3]: | ||
zero_stage = 1 | ||
baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) | ||
accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) | ||
assert ( | ||
baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] | ||
), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' | ||
assert ( | ||
baseline_not_trained["f1"] == accelerator_not_trained["f1"] | ||
), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' | ||
assert ( | ||
baseline_trained["accuracy"] == accelerator_trained["accuracy"] | ||
), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' | ||
assert ( | ||
baseline_trained["f1"] == accelerator_trained["f1"] | ||
), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' | ||
|
||
torch.distributed.destroy_process_group() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like this is still using TE - intended?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, needed to update these scripts, thought I had only pushed single-GPU! Though now DDP and FSDP are good, verifying DeepSpeed in a moment