From 0b5a9c16d89f9951112d4b878877d07eecab322d Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 9 Jan 2025 11:26:25 -0500
Subject: [PATCH 1/9] Bookmark

---
 benchmarks/fp8/torchao/Dockerfile           |  12 ++
 benchmarks/fp8/torchao/README.md            |  32 ++++
 benchmarks/fp8/torchao/ddp.py               | 144 +++++++++++++++
 benchmarks/fp8/torchao/distrib_deepspeed.py | 190 ++++++++++++++++++++
 benchmarks/fp8/torchao/fp8_utils.py         | 116 ++++++++++++
 benchmarks/fp8/torchao/fsdp.py              | 161 +++++++++++++++++
 benchmarks/fp8/torchao/non_distributed.py   | 125 +++++++++++++
 7 files changed, 780 insertions(+)
 create mode 100644 benchmarks/fp8/torchao/Dockerfile
 create mode 100644 benchmarks/fp8/torchao/README.md
 create mode 100644 benchmarks/fp8/torchao/ddp.py
 create mode 100644 benchmarks/fp8/torchao/distrib_deepspeed.py
 create mode 100644 benchmarks/fp8/torchao/fp8_utils.py
 create mode 100644 benchmarks/fp8/torchao/fsdp.py
 create mode 100644 benchmarks/fp8/torchao/non_distributed.py

diff --git a/benchmarks/fp8/torchao/Dockerfile b/benchmarks/fp8/torchao/Dockerfile
new file mode 100644
index 00000000000..88c21934d4e
--- /dev/null
+++ b/benchmarks/fp8/torchao/Dockerfile
@@ -0,0 +1,12 @@
+FROM nvcr.io/nvidia/pytorch:24.07-py3
+
+RUN pip install transformers evaluate datasets
+RUN git clone https://github.com/huggingface/accelerate.git
+
+RUN cd accelerate && \
+ pip install -e . && \
+ cd benchmarks/fp8
+
+RUN /bin/bash
+
+
diff --git a/benchmarks/fp8/torchao/README.md b/benchmarks/fp8/torchao/README.md
new file mode 100644
index 00000000000..d5abadaf64e
--- /dev/null
+++ b/benchmarks/fp8/torchao/README.md
@@ -0,0 +1,32 @@
+# FP8 Benchmarks
+
+Comparing and running [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8) FP8 with accelerate
+
+## Overview
+
+This repo provides scripts which compare native `torchao` model training against `accelerate`'s own integration. Each modeling type is segmented out via a script, supporting the following:
+
+* Single GPU training (`non_distributed.py`)
+* Multi-GPU training via DistributedDataParallel (`ddp.py`)
+* Fully Sharded Data Parallelism (`fsdp.py`)
+* DeepSpeed ZeRO 1-3 (`distrib_deepspeed.py`)
+
+To run them, it's recommended to use a docker image (see the attached `Dockerfile`) and not install `torchao` manually.
+
+## Running:
+
+There are official Docker images located at `huggingface/accelerate:gpu-fp8-torchao-nightly` which can be used.
+
+You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.
+
+For single GPU, run it via `python`:
+
+```bash
+python non_distributed.py
+```
+
+For the rest, run it via `accelerate launch`:
+
+```bash
+accelerate launch ddp.py # or distrib_deepspeed.py, fsdp.py
+```
\ No newline at end of file
diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py
new file mode 100644
index 00000000000..ba708a27be4
--- /dev/null
+++ b/benchmarks/fp8/torchao/ddp.py
@@ -0,0 +1,144 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. + +This particular script verifies this for DDP training. +""" + +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from torch.nn.parallel import DistributedDataParallel as DDP +from transformer_engine.common.recipe import DelayedScaling + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + model.to(device) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + new_named_params = get_named_parameters(model) + + # Convert the model to DDP + device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index + model = DDP(model, device_ids=device_ids, output_device=output_device) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer = accelerator.prepare(model, optimizer) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in 
train_dataloader:
+            outputs = model(**batch)
+            loss = outputs.loss
+            accelerator.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
+            lr_scheduler.step()
+
+    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+
+    assert (
+        trained_model_results["accuracy"] > base_model_results["accuracy"]
+    ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}'
+    assert (
+        trained_model_results["f1"] > base_model_results["f1"]
+    ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+    return base_model_results, trained_model_results
+
+
+if __name__ == "__main__":
+    baseline_not_trained, baseline_trained = train_baseline()
+    accelerator_not_trained, accelerator_trained = train_integration()
+
+    assert (
+        baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+    ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+    assert (
+        baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+    ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+    assert (
+        baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+    ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+    assert (
+        baseline_trained["f1"] == accelerator_trained["f1"]
+    ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+
+    torch.distributed.destroy_process_group()
diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py
new file mode 100644
index 00000000000..e678deb3659
--- /dev/null
+++ b/benchmarks/fp8/torchao/distrib_deepspeed.py
@@ -0,0 +1,190 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.
+
+This particular script verifies this for DeepSpeed ZeRO training.
+""" + +from unittest.mock import patch + +import deepspeed +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from transformer_engine.common.recipe import DelayedScaling + +from accelerate import Accelerator, DeepSpeedPlugin +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(zero_stage: int = 1): + # This forces transformers to think Zero-3 Init should be used + with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: + mock.return_value = zero_stage == 3 + set_seed(42) + + accelerator = Accelerator() + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + new_named_params = get_named_parameters(model) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + import numpy as np + + config = { + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": zero_stage, + "offload_optimizer": {"device": "none", "nvme_path": None}, + "offload_param": {"device": "none", "nvme_path": None}, + "stage3_gather_16bit_weights_on_model_save": False, + }, + "gradient_clipping": 1.0, + "steps_per_print": np.inf, + "bf16": {"enabled": True}, + "fp16": {"enabled": False}, + "zero_allow_untested_optimizer": True, + } + + ( + model, + optimizer, + _, + _, + ) = deepspeed.initialize( + model=model, + optimizer=optimizer, + config_params=config, + ) + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + model_outputs = [] + data = [] + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + model.backward(loss) + model.step() + for _ in range(accelerator.num_processes): + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results, model_outputs, data + + +def train_integration(zero_stage: int = 1): + set_seed(42) + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", 
**FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=zero_stage, + zero3_init_flag=zero_stage == 3, + ) + accelerator = Accelerator( + mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin + ) + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 + + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + model_outputs = [] + data = [] + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results, model_outputs, data + + +if __name__ == "__main__": + # for zero_stage in [1, 2, 3]: + zero_stage = 1 + baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/torchao/fp8_utils.py b/benchmarks/fp8/torchao/fp8_utils.py new file mode 100644 index 00000000000..d28702e05ff --- /dev/null +++ b/benchmarks/fp8/torchao/fp8_utils.py @@ -0,0 +1,116 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def get_dataloaders(model_name: str, batch_size: int = 16): + from datasets import load_dataset + from torch.utils.data import DataLoader + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + # starting with the main process first: + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", "sentence2"], + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + return tokenizer.pad( + examples, + padding="longest", + pad_to_multiple_of=16, # Specific for FP8 + return_tensors="pt", + ) + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], + shuffle=False, + collate_fn=collate_fn, + batch_size=16, + drop_last=True, + ) + + return train_dataloader, eval_dataloader + + +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + from torch.optim import AdamW + from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup + + from accelerate import Accelerator + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + + +def get_named_parameters(model): + """ + Same thing as `Accelerator.get_named_parameters` Returns a list of the named parameters of the model (extracted + from parallel) + """ + from accelerate.utils import extract_model_from_parallel + + model = extract_model_from_parallel(model) + return {n: p for n, p in model.named_parameters()} + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = 
model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() diff --git a/benchmarks/fp8/torchao/fsdp.py b/benchmarks/fp8/torchao/fsdp.py new file mode 100644 index 00000000000..418122185e1 --- /dev/null +++ b/benchmarks/fp8/torchao/fsdp.py @@ -0,0 +1,161 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. + +This particular script verifies this for FSDP training. +""" + +from functools import partial + +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformer_engine.common.recipe import DelayedScaling +from transformers.models.bert import BertLayer + +from accelerate import Accelerator +from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + +FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + model.to(device) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + new_named_params = get_named_parameters(model) + + # Convert the model to FSDP + model = FSDP( + model, + use_orig_params=True, + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + auto_wrap_policy=FSDP_WRAP_POLICY, + ) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + fsdp_plugin = FSDPPlugin( + auto_wrap_policy=FSDP_WRAP_POLICY, + use_orig_params=True, + mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + ) + accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer = accelerator.prepare(model, optimizer) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py new file mode 100644 index 
00000000000..81ebec12fe1 --- /dev/null +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -0,0 +1,125 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. + +This particular script verifies this for single GPU training. +""" + +import evaluate +import torch +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed + +from torchao.float8 import convert_to_float8_training + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def module_filter_func(module, *args): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + + return True + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + model.to("cuda") + convert_to_float8_training(model, module_filter_fn=module_filter_func) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + from accelerate.utils.modeling import get_mixed_precision_context_manager + from accelerate.utils.operations import convert_outputs_to_fp32 + autocast_context = get_mixed_precision_context_manager(True, {"dtype": torch.bfloat16}) + model_forward_func = model.forward + model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) + + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + for batch in train_dataloader: + outputs = model(**batch) + loss 
= outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + baseline_not_trained, baseline_trained = train_baseline() + # accelerator_not_trained, accelerator_trained = train_integration() + # assert ( + # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + # assert ( + # baseline_not_trained["f1"] == accelerator_not_trained["f1"] + # ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + # assert ( + # baseline_trained["accuracy"] == accelerator_trained["accuracy"] + # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + # assert ( + # baseline_trained["f1"] == accelerator_trained["f1"] + # ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' From b2cce71d3a9d272638209368b81c55a27c4807ab Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 10 Jan 2025 10:09:32 -0500 Subject: [PATCH 2/9] bookmark --- benchmarks/fp8/torchao/non_distributed.py | 139 +++++++++++++++------- 1 file changed, 94 insertions(+), 45 deletions(-) diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 81ebec12fe1..08d1da99b74 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -20,18 +20,106 @@ import evaluate import torch -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from torchao.float8 import convert_to_float8_training +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils import FP8RecipeKwargs, set_seed -from torchao.float8 import convert_to_float8_training MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") +def get_dataloaders(model_name: str, batch_size: int = 16): + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + # starting with the main process first: + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", 
"sentence2"], + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + return tokenizer.pad( + examples, + padding="longest", + pad_to_multiple_of=16, # Specific for FP8 + return_tensors="pt", + ) + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], + shuffle=False, + collate_fn=collate_fn, + batch_size=16, + drop_last=True, + ) + + return train_dataloader, eval_dataloader + + +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + + + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() + + def module_filter_func(module, *args): if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: @@ -48,50 +136,11 @@ def train_baseline(): base_model_results = evaluate_model(model, eval_dataloader, METRIC) model.train() - from accelerate.utils.modeling import get_mixed_precision_context_manager - from accelerate.utils.operations import convert_outputs_to_fp32 - autocast_context = get_mixed_precision_context_manager(True, {"dtype": torch.bfloat16}) - model_forward_func = model.forward - model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) - - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() - - trained_model_results = evaluate_model(model, eval_dataloader, METRIC) - - assert ( - trained_model_results["accuracy"] > base_model_results["accuracy"] - ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' - assert ( - trained_model_results["f1"] > base_model_results["f1"] - ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' - - return base_model_results, trained_model_results - - -def train_integration(): - 
FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] - AcceleratorState()._reset_state(True) - accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) - set_seed(42) - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( - MODEL_NAME, accelerator=accelerator - ) - - model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) - base_model_results = evaluate_model(model, eval_dataloader, METRIC) - model.train() - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) + loss = outputs.loss + loss.backward() optimizer.step() optimizer.zero_grad() lr_scheduler.step() From e1a130411fb667fe46bc90baba6fd119fdc162d4 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 16 Jan 2025 08:56:35 -0500 Subject: [PATCH 3/9] Add torchao base example --- benchmarks/fp8/torchao/non_distributed.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 08d1da99b74..732a23a2570 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -20,6 +20,7 @@ import evaluate import torch +from functools import partial from datasets import load_dataset from torch.optim import AdamW from torch.utils.data import DataLoader @@ -104,8 +105,6 @@ def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=No return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - - def evaluate_model(model, dataloader, metric, accelerator=None): "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" model.eval() @@ -120,19 +119,31 @@ def evaluate_model(model, dataloader, metric, accelerator=None): return metric.compute() -def module_filter_func(module, *args): +def module_filter_func(module, fqn, first_layer_name=None, last_layer_name=None): if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: return False - + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False return True def train_baseline(): set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + + func = partial(module_filter_func, first_layer_name=first_linear, last_layer_name=last_linear) model.to("cuda") - convert_to_float8_training(model, module_filter_fn=module_filter_func) + convert_to_float8_training(model, module_filter_fn=func) base_model_results = evaluate_model(model, eval_dataloader, METRIC) model.train() From be210dbaa923760549dccda40428e056480aa434 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 16 Jan 2025 10:46:05 -0500 Subject: [PATCH 4/9] Currently broken --- 
benchmarks/fp8/torchao/non_distributed.py | 39 ++++- src/accelerate/accelerator.py | 112 +++++++++----- src/accelerate/utils/__init__.py | 6 + src/accelerate/utils/ao.py | 112 ++++++++++++++ src/accelerate/utils/dataclasses.py | 176 +++++++++++++--------- src/accelerate/utils/imports.py | 20 +++ 6 files changed, 348 insertions(+), 117 deletions(-) create mode 100644 src/accelerate/utils/ao.py diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 732a23a2570..81eb0d2bc73 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -28,8 +28,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator -from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils import AORecipeKwargs, set_seed MODEL_NAME = "bert-base-cased" @@ -119,7 +118,7 @@ def evaluate_model(model, dataloader, metric, accelerator=None): return metric.compute() -def module_filter_func(module, fqn, first_layer_name=None, last_layer_name=None): +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): if isinstance(module, torch.nn.Linear): if module.in_features % 16 != 0 or module.out_features % 16 != 0: return False @@ -141,7 +140,7 @@ def train_baseline(): first_linear = name last_linear = name - func = partial(module_filter_func, first_layer_name=first_linear, last_layer_name=last_linear) + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) model.to("cuda") convert_to_float8_training(model, module_filter_fn=func) base_model_results = evaluate_model(model, eval_dataloader, METRIC) @@ -168,9 +167,37 @@ def train_baseline(): return base_model_results, trained_model_results +def train_integration(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()]) + model = accelerator.prepare(model) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + if __name__ == "__main__": - baseline_not_trained, baseline_trained = train_baseline() - # accelerator_not_trained, accelerator_trained = train_integration() + # baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() # assert ( # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 
274efa019f7..923f2693764 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -29,6 +29,7 @@ from types import MethodType from typing import Any, Callable, Union +from accelerate.utils.imports import is_torchao_available import torch import torch.utils.hooks as hooks from huggingface_hub import split_torch_state_dict_into_shards @@ -49,6 +50,9 @@ WEIGHTS_NAME, WEIGHTS_PATTERN_NAME, AutocastKwargs, + AORecipeKwargs, + TERecipeKwargs, + MSAMPRecipeKwargs, DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, @@ -72,6 +76,7 @@ clean_state_dict_for_safetensors, compare_versions, convert_model, + convert_to_float8_training, convert_outputs_to_fp32, ensure_weights_retied, extract_model_from_parallel, @@ -381,45 +386,39 @@ def __init__( self.scaler_handler = None self.init_handler = None self.fp8_recipe_handler = None + self.ao_recipe_handler = None + self.te_recipe_handler = None + self.msamp_recipe_handler = None self.autocast_handler = None self.profile_handler = None self.has_lomo_optimizer = False + found_handlers = set() + handler_class_to_attr = { + DistributedDataParallelKwargs: "ddp_handler", + GradScalerKwargs: "scaler_handler", + InitProcessGroupKwargs: "init_handler", + FP8RecipeKwargs: "fp8_recipe_handler", + AutocastKwargs: "autocast_handler", + ProfileKwargs: "profile_handler", + AORecipeKwargs: "ao_recipe_handler", + TERecipeKwargs: "te_recipe_handler", + MSAMPRecipeKwargs: "msamp_recipe_handler", + } + self.has_fp8_handler = False if kwargs_handlers is not None: for handler in kwargs_handlers: assert isinstance( handler, KwargsHandler ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`." - if isinstance(handler, DistributedDataParallelKwargs): - if self.ddp_handler is not None: - raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.") - else: - self.ddp_handler = handler - elif isinstance(handler, GradScalerKwargs): - if self.scaler_handler is not None: - raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.") - else: - self.scaler_handler = handler - elif isinstance(handler, InitProcessGroupKwargs): - if self.init_handler is not None: - raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.") - else: - self.init_handler = handler - elif isinstance(handler, FP8RecipeKwargs): - if self.fp8_recipe_handler is not None: - raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.") - else: - self.fp8_recipe_handler = handler - elif isinstance(handler, AutocastKwargs): - if self.autocast_handler is not None: - raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.") - else: - self.autocast_handler = handler - elif isinstance(handler, ProfileKwargs): - if self.profile_handler is not None: - raise ValueError("You can only pass one `ProfileKwargs` in `kwargs_handler`.") - else: - self.profile_handler = handler + # Add the handler class to the set of found handlers + if handler.__class__ in found_handlers: + raise ValueError(f"You can only pass one {handler.__class__} in `kwargs_handlers`.") + found_handlers.add(handler.__class__) + handler_attr = handler_class_to_attr[handler.__class__] + setattr(self, handler_attr, handler) + if "recipe_handler" in handler_attr and not self.has_fp8_handler: + self.has_fp8_handler = True kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {} self.state = AcceleratorState( @@ -433,17 
+432,27 @@ def __init__( **kwargs, ) - if self.state.mixed_precision == "fp8" and self.fp8_recipe_handler is None: - self.fp8_recipe_handler = FP8RecipeKwargs() + # Check for automatic FP8 recipe creation + if self.state.mixed_precision == "fp8" and not self.has_fp8_handler: + # Prioritize TE -> AO -> MSAMP + if is_torchao_available(): + self.ao_recipe_handler = AORecipeKwargs() + elif is_transformer_engine_available(): + self.te_recipe_handler = TERecipeKwargs() + elif is_msamp_available(): + self.msamp_recipe_handler = MSAMPRecipeKwargs() + else: + raise ImportError("Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed.") self.delayed_fp8_autocast = False - if self.fp8_recipe_handler is not None: + if self.has_fp8_handler: # We already check if FP8 is available during `self.state` if self.state.mixed_precision != "fp8" and ( self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED) ): - raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.") - self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in ( + raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.") + # DEPRECATE once 2.0 is released + self.delayed_fp8_autocast = self.fp8_backend == "TE" and self.distributed_type in ( DistributedType.MULTI_GPU, DistributedType.FSDP, ) @@ -1329,6 +1338,8 @@ def prepare(self, *args, device_placement=None): args = self._prepare_ipex_or_xpu(*args) if self.fp8_backend == "TE": args = self._prepare_te(*args) + elif self.fp8_backend == "AO": + args = self._prepare_ao(*args) if self.distributed_type == DistributedType.DEEPSPEED: result = self._prepare_deepspeed(*args) elif self.distributed_type == DistributedType.MEGATRON_LM: @@ -1414,7 +1425,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e # We prepare TE after, allowing for bf16 autocast to happen first if self.fp8_backend == "TE" and not self.delayed_fp8_autocast: - model = apply_fp8_autowrap(model, self.fp8_recipe_handler) + model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler) if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr( model, "hf_device_map", False @@ -1605,7 +1616,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e model = xmp.MpModelWrapper(model).to(self.device) # Now we can apply the FP8 autocast if self.delayed_fp8_autocast: - model = apply_fp8_autowrap(model, self.fp8_recipe_handler) + model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler) # torch.compile should be called last and only if the model isn't already compiled. if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model): if not is_torch_version(">=", "2.0"): @@ -1613,6 +1624,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) return model + def _prepare_ao(self, *args): + if not is_torchao_available(): + raise ImportError("`torchao` was not found on your system. 
Please ensure that `torchao` is installed") + for model in self._models: + convert_to_float8_training(model, config=self.ao_recipe_handler.config, module_filter_func=self.ao_recipe_handler.module_filter_func) + return args + def _prepare_te(self, *args): if not is_transformer_engine_available(): raise ImportError( @@ -1767,7 +1785,7 @@ def _prepare_deepspeed(self, *args): if model is not None: # If we are using FP8, we need to apply the autowrap now - if getattr(self.fp8_recipe_handler, "backend", None) == "TE": + if self.fp8_backend == "TE": model = apply_fp8_autowrap(model, self.fp8_recipe_handler) # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules deepspeed_plugin.set_moe_leaf_modules(model) @@ -2062,7 +2080,12 @@ def _prepare_msamp(self, *args, device_placement): f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP." ) else: - model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level) + # DEPRECATE @ 2.0 + if self.fp8_recipe_handler is not None: + opt_level = self.fp8_recipe_handler.opt_level + else: + opt_level = self.msamp_recipe_handler.opt_level + model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level) for i in range(len(result)): if isinstance(result[i], torch.nn.Module): result[i] = model @@ -3602,8 +3625,15 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None: @property def fp8_backend(self): "Returns the configured backend for training in FP8" - if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None: - return self.fp8_recipe_handler.backend + if self.has_fp8_handler: + if self.fp8_recipe_handler is not None: + return self.fp8_recipe_handler.backend + elif self.ao_recipe_handler is not None: + return "AO" + elif self.te_recipe_handler is not None: + return "TE" + elif self.msamp_recipe_handler is not None: + return "MSAMP" elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp: return "MSAMP" return None diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 273808597b5..6219adbfee0 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .ao import convert_to_float8_training from .constants import ( MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, MODEL_NAME, @@ -33,6 +34,7 @@ ) from .dataclasses import ( AutocastKwargs, + AORecipeKwargs, BnbQuantizationConfig, ComputeEnvironment, CustomDtype, @@ -57,6 +59,8 @@ SageMakerDistributedType, TensorInformation, TorchDynamoPlugin, + TERecipeKwargs, + MSAMPRecipeKwargs, add_model_config_to_megatron_parser, ) from .environment import ( @@ -77,6 +81,7 @@ ) from .imports import ( deepspeed_required, + torchao_required, get_ccl_version, is_4bit_bnb_available, is_8bit_bnb_available, @@ -114,6 +119,7 @@ is_tensorboard_available, is_timm_available, is_torch_xla_available, + is_torchao_available, is_torchdata_available, is_torchdata_stateful_dataloader_available, is_torchvision_available, diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py new file mode 100644 index 00000000000..1d21738c495 --- /dev/null +++ b/src/accelerate/utils/ao.py @@ -0,0 +1,112 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Needed utilities for torchao FP8 training. +""" + +from functools import partial + +import torch + +from .imports import torchao_required + + +def find_first_last_linear_layers(model: torch.nn.Module): + """ + Finds the first and last linear layer names in a model. + + This is needed during FP8 to avoid issues with + instability by keeping the first and last layers + unquantized. + + Ref: https://x.com/xariusrke/status/1826669142604141052 + """ + first_linear, last_linear = None, None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + return first_linear, last_linear + + +def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name) -> bool: + """ + A function which will check if `module` is: + - a `torch.nn.Linear` layer + - has in_features and out_features divisible by 16 + - is not the first or last layer of the model. + + Args: + module (`torch.nn.Module`): + The module to check. + layer_name (`str`): + The fully qualified name of the layer. + first_layer_name (`str`): + The name of the first layer of the model. + last_layer_name (`str`): + The name of the last layer of the model. + """ + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + # TODO: apply this to all FP8 backends + if layer_name in (first_layer_name, last_layer_name): + return False + return True + + +@torchao_required +def convert_to_float8_training( + model: torch.nn.Module, + config=None, + module_filter_func=None, + ): + """ + Converts all `nn.Linear` layers in the model (except the first and last) + to torchao's `Float8Linear` layer inplace. + + Args: + model (`torch.nn.Module`): + The model to convert. + config (`torchao.float8.Float8LinearConfig`, *optional*): + The configuration for the FP8 training. Recommended to utilize + `torchao.float8.recipe_name_to_linear_config` to generate this. + In general, the default config should be sufficient. + module_filter_func (`Callable`, *optional*): + Optional function that must take in a module and layer name, + and returns a boolean indicating whether the module should be + converted to FP8. Defaults to `filter_linear_layers`. See + it for an example. 
+ + Example: + + ```python + from accelerate.utils.ao import convert_to_float8_training + model = MyModel() + model.to("cuda") + convert_to_float8_training(model) + + model.train() + ``` + """ + from torchao.float8 import convert_to_float8_training + + first_linear, last_linear = find_first_last_linear_layers(model) + if module_filter_func is None: + module_filter_func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) + convert_to_float8_training(model, config, module_filter_func) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 39e048a6039..f1240b7aeb8 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -22,13 +22,15 @@ import functools import os import warnings +import logging from contextlib import contextmanager from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args, TYPE_CHECKING import torch +from .ao import filter_linear_layers from .constants import ( FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, @@ -47,6 +49,12 @@ ) from .versions import compare_versions, is_torch_version +if TYPE_CHECKING: + # Mock imports for type checking + from torchao.float8 import Float8LinearConfig + +logger = logging.getLogger(__name__) + class KwargsHandler: """ @@ -279,40 +287,57 @@ def __post_init__(self): AmaxComputeAlgorithm = Literal["max", "most_recent"] +# FP8 training recipe kwargs +@dataclass +class AORecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `torchao` FP8. + + Args: + recipe_name (`str`, *optional*, default to `None`): + The name of the recipe to use for FP8 training. Should + be compatible with `torchao.float8.recipe_name_to_linear_config`. + config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`): + The configuration for the FP8 training. In general, the default config + should be sufficient. + module_filter_func (`Callable`, *optional*, default to `None`): + Optional function that must take in a module and layer name, + and returns a boolean indicating whether the module should be + converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See + it for an example. + """ + recipe_name: str = None + config: "Float8LinearConfig" = None + module_filter_func: Callable = None + + def __post_init__(self): + if self.module_filter_func is None: + self.module_filter_func = filter_linear_layers + + @dataclass -class FP8RecipeKwargs(KwargsHandler): +class TERecipeKwargs(KwargsHandler): """ Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision - training with `transformer-engine` or `ms-amp`. + training with `transformer-engine`. - For more information on `transformer-engine` args, please refer to the API + For more information on the args, please refer to the API [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html). - For more information on the `ms-amp` args, please refer to the Optimization Level - [documentation](https://azure.github.io/MS-AMP/docs/user-tutorial/optimization-level). 
- ```python from accelerate import Accelerator - from accelerate.utils import FP8RecipeKwargs + from accelerate.utils import TERecipeKwargs - kwargs = FP8RecipeKwargs(backend="te", fp8_format="HYBRID") + kwargs = TERecipeKwargs(fp8_format="HYBRID") accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs]) ``` - To use MS-AMP as an engine, pass `backend="msamp"` and the `optimization_level`: - - ```python - kwargs = FP8RecipeKwargs(backend="msamp", optimization_level="02") - ``` - Args: - backend (`str`, *optional*): - Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine). If not passed, - will use whichever is available in the environment, prioritizing MS-AMP. use_autocast_during_eval (`bool`, *optional*, default to `False`): Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`. margin (`int`, *optional*, default to 0): @@ -328,21 +353,8 @@ class FP8RecipeKwargs(KwargsHandler): The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`. override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`): Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. - optimization_level (`str`), one of `O1`, `O2`. (default is `O2`): - What level of 8-bit collective communication should be used with MS-AMP. In general: - * O1: Weight gradients and `all_reduce` communications are done in fp8, reducing GPU - memory usage and communication bandwidth - * O2: First-order optimizer states are in 8-bit, and second order states are in FP16. - Only available when using Adam or AdamW. This maintains accuracy and can potentially save the - highest memory. - * 03: Specifically for DeepSpeed, implements capabilities so weights and master weights of models - are stored in FP8. If `fp8` is selected and deepspeed is enabled, will be used by default. (Not - available currently). """ - - backend: Backend = None use_autocast_during_eval: bool = None - opt_level: OptLevel = None margin: int = None interval: int = None fp8_format: FP8Format = None @@ -352,50 +364,74 @@ class FP8RecipeKwargs(KwargsHandler): def __post_init__(self): env_prefix = "ACCELERATE_FP8_" + if not is_transformer_engine_available(): + raise ImportError( + "TransformerEngine is not available. Please install it or use a different backend." 
+ ) + if self.use_autocast_during_eval is None: + self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") + if self.margin is None: + self.margin = int(os.environ.get(env_prefix + "MARGIN", 0)) + if self.interval is None: + self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1)) + if self.fp8_format is None: + self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID") + self.fp8_format = self.fp8_format.upper() + if self.fp8_format not in get_args(FP8Format): + raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") + if self.amax_compute_algo is None: + self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent") + self.amax_compute_algo = self.amax_compute_algo.lower() + if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): + raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") + if self.amax_history_len is None: + self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024)) + if self.override_linear_precision is None: + fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP") + dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD") + wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD") + self.override_linear_precision = (fprop, dgrad, wgrad) + + +@dataclass +class MSAMPRecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `ms-amp`. + """ + opt_level: OptLevel = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + if self.opt_level is None: + self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2") + if self.opt_level not in get_args(OptLevel): + raise ValueError(f"`opt_level` must be one of {' or '.join(get_args(OptLevel))}") + + +@dataclass +class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): + """ + Deprecated. Please use one of the proper FP8 recipe + kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` + instead. + """ + + backend: Backend = None + + def __post_init__(self): + env_prefix = "ACCELERATE_FP8_" + warnings.warn( + "FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. " + "Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.", + FutureWarning, + ) default_backend = "msamp" if is_msamp_available() else "te" if self.backend is None: self.backend = os.environ.get(env_prefix + "BACKEND", default_backend) self.backend = self.backend.upper() if self.backend not in get_args(Backend): - raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).") - # Check TE args - if self.backend == "TE": - if not is_transformer_engine_available(): - raise ValueError( - "TransformerEngine is not available. Please either install it, or use the 'MSAMP' backend (if installed)." 
- ) - if self.use_autocast_during_eval is None: - self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") - if self.margin is None: - self.margin = int(os.environ.get(env_prefix + "MARGIN", 0)) - if self.interval is None: - self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1)) - if self.fp8_format is None: - self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID") - self.fp8_format = self.fp8_format.upper() - if self.fp8_format not in get_args(FP8Format): - raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") - if self.amax_compute_algo is None: - self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent") - self.amax_compute_algo = self.amax_compute_algo.lower() - if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): - raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") - if self.amax_history_len is None: - self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024)) - if self.override_linear_precision is None: - fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP") - dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD") - wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD") - self.override_linear_precision = (fprop, dgrad, wgrad) - elif self.backend == "MSAMP": - if not is_msamp_available(): - raise ValueError( - "MS-AMP is not available. Please either install it, or use the 'TE' backend (if installed)." - ) - if self.opt_level is None: - self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2") - if self.opt_level not in get_args(OptLevel): - raise ValueError(f"`optimization_level` must be one of {' or '.join(get_args(OptLevel))}") + raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.") # Literal diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index aeafe91cf3c..3ba98691902 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -142,6 +142,10 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): return True +def is_torchao_available(): + return _is_package_available("torchao") + + def is_deepspeed_available(): if is_mlu_available(): return _is_package_available("deepspeed", metadata_name="deepspeed-mlu") @@ -423,6 +427,22 @@ def is_torchdata_stateful_dataloader_available(): return False +def torchao_required(func): + """ + A decorator that ensures the decorated function is only called when torchao is available. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if not is_torchao_available(): + raise ImportError( + "`torchao` is not available, please install it before calling this function via `pip install torchao`." 
+ ) + return func(*args, **kwargs) + + return wrapper + + # TODO: Rework this into `utils.deepspeed` and migrate the "core" chunks into `accelerate.deepspeed` def deepspeed_required(func): """ From fbeb5a7bc27c020a0c8e30ebad44fc6fe48571b9 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 16 Jan 2025 10:55:29 -0500 Subject: [PATCH 5/9] Clean --- benchmarks/fp8/torchao/non_distributed.py | 35 +++++++++++++---------- src/accelerate/accelerator.py | 24 ++++++++++------ src/accelerate/utils/__init__.py | 8 +++--- src/accelerate/utils/ao.py | 28 ++++++++---------- src/accelerate/utils/dataclasses.py | 34 +++++++++------------- src/accelerate/utils/imports.py | 2 +- 6 files changed, 67 insertions(+), 64 deletions(-) diff --git a/benchmarks/fp8/torchao/non_distributed.py b/benchmarks/fp8/torchao/non_distributed.py index 81eb0d2bc73..e2426f162d5 100644 --- a/benchmarks/fp8/torchao/non_distributed.py +++ b/benchmarks/fp8/torchao/non_distributed.py @@ -18,9 +18,10 @@ This particular script verifies this for single GPU training. """ +from functools import partial + import evaluate import torch -from functools import partial from datasets import load_dataset from torch.optim import AdamW from torch.utils.data import DataLoader @@ -28,6 +29,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator +from accelerate.state import AcceleratorState from accelerate.utils import AORecipeKwargs, set_seed @@ -169,8 +171,10 @@ def train_baseline(): def train_integration(): set_seed(42) - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()]) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) model = accelerator.prepare(model) base_model_results = evaluate_model(model, eval_dataloader, METRIC) model.train() @@ -196,17 +200,18 @@ def train_integration(): if __name__ == "__main__": - # baseline_not_trained, baseline_trained = train_baseline() + baseline_not_trained, baseline_trained = train_baseline() + AcceleratorState._reset_state(True) accelerator_not_trained, accelerator_trained = train_integration() - # assert ( - # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - # assert ( - # baseline_not_trained["f1"] == accelerator_not_trained["f1"] - # ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - # assert ( - # baseline_trained["accuracy"] == accelerator_trained["accuracy"] - # ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - # assert ( - # baseline_trained["f1"] == accelerator_trained["f1"] - # ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == 
accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 923f2693764..d3ba6f4f74a 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -29,11 +29,12 @@ from types import MethodType from typing import Any, Callable, Union -from accelerate.utils.imports import is_torchao_available import torch import torch.utils.hooks as hooks from huggingface_hub import split_torch_state_dict_into_shards +from accelerate.utils.imports import is_torchao_available + from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches from .logging import get_logger @@ -49,10 +50,8 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, WEIGHTS_PATTERN_NAME, - AutocastKwargs, AORecipeKwargs, - TERecipeKwargs, - MSAMPRecipeKwargs, + AutocastKwargs, DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, @@ -66,18 +65,20 @@ KwargsHandler, LoggerType, MegatronLMPlugin, + MSAMPRecipeKwargs, PrecisionType, ProfileKwargs, ProjectConfiguration, RNGType, + TERecipeKwargs, TorchDynamoPlugin, apply_fp8_autowrap, check_os_kernel, clean_state_dict_for_safetensors, compare_versions, convert_model, - convert_to_float8_training, convert_outputs_to_fp32, + convert_to_float8_training, ensure_weights_retied, extract_model_from_parallel, gather, @@ -442,7 +443,9 @@ def __init__( elif is_msamp_available(): self.msamp_recipe_handler = MSAMPRecipeKwargs() else: - raise ImportError("Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed.") + raise ImportError( + "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed." + ) self.delayed_fp8_autocast = False if self.has_fp8_handler: @@ -1627,8 +1630,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e def _prepare_ao(self, *args): if not is_torchao_available(): raise ImportError("`torchao` was not found on your system. 
Please ensure that `torchao` is installed") - for model in self._models: - convert_to_float8_training(model, config=self.ao_recipe_handler.config, module_filter_func=self.ao_recipe_handler.module_filter_func) + for arg in args: + if isinstance(arg, torch.nn.Module): + convert_to_float8_training( + arg, + config=self.ao_recipe_handler.config, + module_filter_func=self.ao_recipe_handler.module_filter_func, + ) return args def _prepare_te(self, *args): diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 6219adbfee0..502c9c04b88 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -33,8 +33,8 @@ XPU_PROFILING_AVAILABLE_PYTORCH_VERSION, ) from .dataclasses import ( - AutocastKwargs, AORecipeKwargs, + AutocastKwargs, BnbQuantizationConfig, ComputeEnvironment, CustomDtype, @@ -52,15 +52,15 @@ KwargsHandler, LoggerType, MegatronLMPlugin, + MSAMPRecipeKwargs, PrecisionType, ProfileKwargs, ProjectConfiguration, RNGType, SageMakerDistributedType, TensorInformation, - TorchDynamoPlugin, TERecipeKwargs, - MSAMPRecipeKwargs, + TorchDynamoPlugin, add_model_config_to_megatron_parser, ) from .environment import ( @@ -81,7 +81,6 @@ ) from .imports import ( deepspeed_required, - torchao_required, get_ccl_version, is_4bit_bnb_available, is_8bit_bnb_available, @@ -129,6 +128,7 @@ is_wandb_available, is_weights_only_available, is_xpu_available, + torchao_required, ) from .modeling import ( align_module_device, diff --git a/src/accelerate/utils/ao.py b/src/accelerate/utils/ao.py index 1d21738c495..e0a2cf93d73 100644 --- a/src/accelerate/utils/ao.py +++ b/src/accelerate/utils/ao.py @@ -27,9 +27,7 @@ def find_first_last_linear_layers(model: torch.nn.Module): """ Finds the first and last linear layer names in a model. - This is needed during FP8 to avoid issues with - instability by keeping the first and last layers - unquantized. + This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized. Ref: https://x.com/xariusrke/status/1826669142604141052 """ @@ -72,31 +70,29 @@ def filter_linear_layers(module, layer_name, first_layer_name, last_layer_name) @torchao_required def convert_to_float8_training( - model: torch.nn.Module, - config=None, - module_filter_func=None, - ): + model: torch.nn.Module, + config=None, + module_filter_func=None, +): """ - Converts all `nn.Linear` layers in the model (except the first and last) - to torchao's `Float8Linear` layer inplace. + Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace. Args: model (`torch.nn.Module`): The model to convert. config (`torchao.float8.Float8LinearConfig`, *optional*): The configuration for the FP8 training. Recommended to utilize - `torchao.float8.recipe_name_to_linear_config` to generate this. - In general, the default config should be sufficient. + `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be + sufficient. module_filter_func (`Callable`, *optional*): - Optional function that must take in a module and layer name, - and returns a boolean indicating whether the module should be - converted to FP8. Defaults to `filter_linear_layers`. See - it for an example. + Optional function that must take in a module and layer name, and returns a boolean indicating whether the + module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example. 
Example: ```python from accelerate.utils.ao import convert_to_float8_training + model = MyModel() model.to("cuda") convert_to_float8_training(model) @@ -109,4 +105,4 @@ def convert_to_float8_training( first_linear, last_linear = find_first_last_linear_layers(model) if module_filter_func is None: module_filter_func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) - convert_to_float8_training(model, config, module_filter_func) + convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index f1240b7aeb8..9aaf35308ff 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -20,17 +20,16 @@ import copy import enum import functools +import logging import os import warnings -import logging from contextlib import contextmanager from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union, get_args import torch -from .ao import filter_linear_layers from .constants import ( FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, @@ -49,6 +48,7 @@ ) from .versions import compare_versions, is_torch_version + if TYPE_CHECKING: # Mock imports for type checking from torchao.float8 import Float8LinearConfig @@ -296,25 +296,20 @@ class AORecipeKwargs(KwargsHandler): Args: recipe_name (`str`, *optional*, default to `None`): - The name of the recipe to use for FP8 training. Should - be compatible with `torchao.float8.recipe_name_to_linear_config`. + The name of the recipe to use for FP8 training. Should be compatible with + `torchao.float8.recipe_name_to_linear_config`. config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`): - The configuration for the FP8 training. In general, the default config - should be sufficient. + The configuration for the FP8 training. In general, the default config should be sufficient. module_filter_func (`Callable`, *optional*, default to `None`): - Optional function that must take in a module and layer name, - and returns a boolean indicating whether the module should be - converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See - it for an example. + Optional function that must take in a module and layer name, and returns a boolean indicating whether the + module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an + example. """ + recipe_name: str = None config: "Float8LinearConfig" = None module_filter_func: Callable = None - def __post_init__(self): - if self.module_filter_func is None: - self.module_filter_func = filter_linear_layers - @dataclass class TERecipeKwargs(KwargsHandler): @@ -354,6 +349,7 @@ class TERecipeKwargs(KwargsHandler): override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`): Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. """ + use_autocast_during_eval: bool = None margin: int = None interval: int = None @@ -365,9 +361,7 @@ class TERecipeKwargs(KwargsHandler): def __post_init__(self): env_prefix = "ACCELERATE_FP8_" if not is_transformer_engine_available(): - raise ImportError( - "TransformerEngine is not available. Please install it or use a different backend." 
- ) + raise ImportError("TransformerEngine is not available. Please install it or use a different backend.") if self.use_autocast_during_eval is None: self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") if self.margin is None: @@ -399,6 +393,7 @@ class MSAMPRecipeKwargs(KwargsHandler): Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision training with `ms-amp`. """ + opt_level: OptLevel = None def __post_init__(self): @@ -412,8 +407,7 @@ def __post_init__(self): @dataclass class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): """ - Deprecated. Please use one of the proper FP8 recipe - kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` + Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` instead. """ diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 3ba98691902..7653f36e60d 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -110,7 +110,7 @@ def is_lomo_available(): def is_fp8_available(): - return is_msamp_available() or is_transformer_engine_available() + return is_msamp_available() or is_transformer_engine_available() or is_torchao_available() def is_cuda_available(): From 3820c40c20d009624db3369953ceefb304e6f2ba Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 08:57:08 -0500 Subject: [PATCH 6/9] DDP varient working --- benchmarks/fp8/torchao/ddp.py | 126 +++++++++++++++++++++++----------- 1 file changed, 85 insertions(+), 41 deletions(-) diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py index ba708a27be4..873f13918c5 100644 --- a/benchmarks/fp8/torchao/ddp.py +++ b/benchmarks/fp8/torchao/ddp.py @@ -13,69 +13,116 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. This particular script verifies this for DDP training. 
""" +from functools import partial + import evaluate import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from datasets import load_dataset from torch.nn.parallel import DistributedDataParallel as DDP -from transformer_engine.common.recipe import DelayedScaling +from torch.optim import AdamW +from torch.utils.data import DataLoader +from torchao.float8 import convert_to_float8_training +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model +from accelerate.utils import AORecipeKwargs, set_seed + +from fp8_utils import get_dataloaders MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() + + +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False + return True + + def train_baseline(): set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) accelerator = Accelerator() device = accelerator.device model.to(device) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - - FP8_RECIPE_KWARGS = {"fp8_format": 
te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) - - new_named_params = get_named_parameters(model) + convert_to_float8_training(model, module_filter_fn=func) # Convert the model to DDP device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index model = DDP(model, device_ids=device_ids, output_device=output_device) - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] - base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch = batch.to(device) - outputs = model(**batch) + for batch in train_dataloader: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) loss = outputs.loss loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) @@ -90,10 +137,8 @@ def train_baseline(): def train_integration(): - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] AcceleratorState()._reset_state(True) - accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()]) set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator @@ -103,14 +148,13 @@ def train_integration(): base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) From 4dfa816fa68436806adf9869c9ded87ca5a2379d Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 09:03:49 -0500 Subject: [PATCH 7/9] FSDP as well --- benchmarks/fp8/torchao/ddp.py | 9 +-- benchmarks/fp8/torchao/fsdp.py | 127 +++++++++++++++++++++------------ 2 files changed, 86 insertions(+), 50 deletions(-) diff --git a/benchmarks/fp8/torchao/ddp.py b/benchmarks/fp8/torchao/ddp.py index 873f13918c5..0b7e6071ac2 100644 --- a/benchmarks/fp8/torchao/ddp.py +++ b/benchmarks/fp8/torchao/ddp.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,19 +22,16 @@ import evaluate import torch -from datasets import load_dataset +from fp8_utils import get_dataloaders from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW -from torch.utils.data import DataLoader from torchao.float8 import convert_to_float8_training -from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup +from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils import AORecipeKwargs, set_seed -from fp8_utils import get_dataloaders - MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") diff --git a/benchmarks/fp8/torchao/fsdp.py b/benchmarks/fp8/torchao/fsdp.py index 418122185e1..a047f27bd86 100644 --- a/benchmarks/fp8/torchao/fsdp.py +++ b/benchmarks/fp8/torchao/fsdp.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. This particular script verifies this for FSDP training. """ @@ -22,20 +22,19 @@ import evaluate import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from fp8_utils import get_dataloaders from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformer_engine.common.recipe import DelayedScaling +from torch.optim import AdamW +from torchao.float8 import convert_to_float8_training +from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup from transformers.models.bert import BertLayer from accelerate import Accelerator from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model +from accelerate.utils import AORecipeKwargs, set_seed MODEL_NAME = "bert-base-cased" @@ -44,23 +43,72 @@ FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, 
train_dataloader, eval_dataloader, lr_scheduler + + +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False + return True + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() + + def train_baseline(): set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) accelerator = Accelerator() device = accelerator.device model.to(device) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - - FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) - - new_named_params = get_named_parameters(model) + convert_to_float8_training(model, module_filter_fn=func) # Convert the model to FSDP model = FSDP( @@ -70,24 +118,18 @@ def train_baseline(): auto_wrap_policy=FSDP_WRAP_POLICY, ) - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] - base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch = batch.to(device) - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + for batch in train_dataloader: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) @@ -102,15 +144,13 @@ def train_baseline(): def train_integration(): - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] AcceleratorState()._reset_state(True) fsdp_plugin = FSDPPlugin( auto_wrap_policy=FSDP_WRAP_POLICY, use_orig_params=True, mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), 
) - accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers) + accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=[AORecipeKwargs()]) set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator @@ -120,14 +160,13 @@ def train_integration(): base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - outputs = model(**batch) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) From 04c9f560139d5c4e1291420649cb4801a807f3a1 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 09:34:33 -0500 Subject: [PATCH 8/9] Works for all but zero3 --- benchmarks/fp8/torchao/distrib_deepspeed.py | 121 ++++++++++---------- 1 file changed, 63 insertions(+), 58 deletions(-) diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py index e678deb3659..d8019524a10 100644 --- a/benchmarks/fp8/torchao/distrib_deepspeed.py +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -13,31 +13,40 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `torchao`. -This particular script verifies this for DDP training. +This particular script verifies this for deepspeed training. 
""" from unittest.mock import patch +from functools import partial import deepspeed import evaluate import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities -from transformer_engine.common.recipe import DelayedScaling from accelerate import Accelerator, DeepSpeedPlugin from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model +from accelerate.utils import AORecipeKwargs, set_seed +from torchao.float8 import convert_to_float8_training MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") +def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None): + if isinstance(module, torch.nn.Linear): + if module.in_features % 16 != 0 or module.out_features % 16 != 0: + return False + # For stability reasons, we skip the first and last linear layers + # Otherwise can lead to the model not training or converging properly + if fqn in (first_layer_name, last_layer_name): + return False + return True + + def train_baseline(zero_stage: int = 1): # This forces transformers to think Zero-3 Init should be used with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: @@ -49,19 +58,17 @@ def train_baseline(zero_stage: int = 1): MODEL_NAME, accelerator=accelerator ) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - new_named_params = get_named_parameters(model) - - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] + first_linear = None + last_linear = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if first_linear is None: + first_linear = name + last_linear = name + func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear) + convert_to_float8_training(model, module_filter_fn=func) - FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + accelerator = Accelerator() import numpy as np @@ -99,17 +106,15 @@ def train_baseline(zero_stage: int = 1): model_outputs = [] data = [] - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - outputs = model(**batch) - data.append(batch.to("cpu")) - model_outputs.append(outputs.logits.to("cpu")) - loss = outputs.loss - model.backward(loss) - model.step() - for _ in range(accelerator.num_processes): - lr_scheduler.step() + for batch in train_dataloader: + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + model.backward(loss) + model.step() + for _ in range(accelerator.num_processes): + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.destroy() @@ -125,15 +130,16 @@ def train_baseline(zero_stage: int = 1): def train_integration(zero_stage: int = 1): set_seed(42) - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] 
AcceleratorState()._reset_state(True) + # This forces transformers to think Zero-3 Init should be used + with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: + mock.return_value = zero_stage == 3 deepspeed_plugin = DeepSpeedPlugin( zero_stage=zero_stage, zero3_init_flag=zero_stage == 3, ) accelerator = Accelerator( - mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin + mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin ) accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 @@ -146,16 +152,15 @@ def train_integration(zero_stage: int = 1): model.train() model_outputs = [] data = [] - for _ in range(2): - for batch in train_dataloader: - outputs = model(**batch) - data.append(batch.to("cpu")) - model_outputs.append(outputs.logits.to("cpu")) - loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() + for batch in train_dataloader: + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.destroy() @@ -171,20 +176,20 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": # for zero_stage in [1, 2, 3]: - zero_stage = 1 - baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) - accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + for zero_stage in [3]: + baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be 
the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' torch.distributed.destroy_process_group() From f8058764ae32935be0b19fa36454ca5e8ba634fb Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 17 Jan 2025 09:51:44 -0500 Subject: [PATCH 9/9] Bookmark: currently zero3 is underperforming --- benchmarks/fp8/torchao/distrib_deepspeed.py | 29 ++++++++++----------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/benchmarks/fp8/torchao/distrib_deepspeed.py b/benchmarks/fp8/torchao/distrib_deepspeed.py index d8019524a10..836238149a0 100644 --- a/benchmarks/fp8/torchao/distrib_deepspeed.py +++ b/benchmarks/fp8/torchao/distrib_deepspeed.py @@ -131,12 +131,10 @@ def train_baseline(zero_stage: int = 1): def train_integration(zero_stage: int = 1): set_seed(42) AcceleratorState()._reset_state(True) - # This forces transformers to think Zero-3 Init should be used - with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: - mock.return_value = zero_stage == 3 deepspeed_plugin = DeepSpeedPlugin( zero_stage=zero_stage, zero3_init_flag=zero_stage == 3, + gradient_clipping=1.0, ) accelerator = Accelerator( mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin @@ -179,17 +177,18 @@ def train_integration(zero_stage: int = 1): for zero_stage in [3]: baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + print(baseline_trained, accelerator_trained) + # assert ( + # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + # assert ( + # baseline_not_trained["f1"] == accelerator_not_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + # assert ( + # baseline_trained["accuracy"] == accelerator_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and 
accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + # assert ( + # baseline_trained["f1"] == accelerator_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' torch.distributed.destroy_process_group()
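
For reference, below is a minimal end-to-end sketch of the `torchao` FP8 path these patches wire up: pass `AORecipeKwargs()` to `Accelerator` with `mixed_precision="fp8"`, call `prepare`, and train as usual, mirroring the `train_integration` functions in the benchmark scripts. The toy model, layer sizes, and random data are illustrative placeholders rather than the benchmark setup; an FP8-capable CUDA GPU and an installed `torchao` are assumed.

```python
# Minimal sketch of the AORecipeKwargs integration (assumes `torchao` and an FP8-capable GPU).
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, set_seed

set_seed(42)

# Toy classifier standing in for the benchmark model. Hidden linear dims are multiples of 16
# so they are eligible for FP8 conversion; the first and last linear layers are skipped by the
# default filter for stability, as in the patches above.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)
optimizer = AdamW(model.parameters(), lr=1e-4)

# Random stand-in data; any (features, labels) dataloader works here.
dataset = TensorDataset(torch.randn(256, 64), torch.randint(0, 2, (256,)))
dataloader = DataLoader(dataset, batch_size=16)

# The AO recipe handler triggers torchao's Float8Linear conversion during `prepare`.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for inputs, labels in dataloader:
    outputs = model(inputs)
    loss = torch.nn.functional.cross_entropy(outputs, labels)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```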