# [BUG] [Fix-Suggested] Model Training Stalls with FSDP when fsdp_use_orig_params=False due to inconsistent model-optimizer state #3256

Closed

traincheck-team opened this issue on Nov 24, 2024 · 7 comments

Bug Description

Users have experienced the model not learning at all when adapting their pipeline to FSDP (the loss stays constant across epochs), as reported in https://github.com/huggingface/accelerate/issues/2665.

Mitigation: setting fsdp_use_orig_params to true makes the model learn again.

We are opening a new issue because the original one has been closed and the root cause was never made clear there.

Environment

The bug is reproducible on the latest stable accelerate version. Below is the `accelerate env` output (the environment likely does not matter, given the root cause we diagnosed below):

`accelerate env` output:

- `Accelerate` version: 1.1.1
- Platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/xxx/miniconda3/envs/fsdp_310/bin/accelerate
- Python version: 3.10.15
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.1+cu118 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 251.50 GB
- GPU type: NVIDIA A16
- `Accelerate` default config:
        Not found

To reproduce:

  1. Install the following dependencies:

    torch
    accelerate
    transformers 
    datasets
    evaluate
    nvidia_smi
    nvidia-ml-py3
    nltk
    peft
    absl-py
    rouge_score
  2. Run the training script (train.py) using run.sh (the same scripts reported in #2665, "FSDP Model not learning during training, loss stays constant"):

    run.sh
    export CUDA_VISIBLE_DEVICES=0
    GPUS=1
    CURRENT_DATE_TIME=$(date '+%Y-%m-%d___%H-%M-%S')
    export CURRENT_DATE_TIME
    TORCH_DISTRIBUTED_DEBUG=DETAIL
    export TORCH_DISTRIBUTED_DEBUG
    
    # Define Environment Variables - use these to alter training scripts below.
    FILE_NAME='train.py'
    USER_PROFILE='userName'
    OUTPUT_DIR='myModel'
    MODEL_AND_TOKENIZER_NAME='google-t5/t5-small'
    REPO_NAME=$USER_PROFILE/$OUTPUT_DIR
    FINETUNE='Yes' # not used
    BRANCH_NAME='main'
    PRECISION='no' # choices ["no", "fp16", "bf16", "fp8"]
    WITH_TRACKING='yes'
    ACCUMULATION_STEPS=2
    LR=0.00002
    WEIGHT_DECAY=0.0
    TRAIN_BATCH_SIZE=1
    VAL_BATCH_SIZE=1
    MAX_INPUT_LENGTH=1024
    MAX_TARGET_LENGTH=128
    EPOCHS=2
    LOG_DIR='logs'
    LOGS=$LOG_DIR/$OUTPUT_DIR/$CURRENT_DATE_TIME
    CONFIG='/home/runhui/machine-learning-issues/AC2665/default.yaml'
    CHECKPOINT_ITERATIONS='epoch'
    mkdir -p $LOGS
    export CONFIG
    
    accelerate launch --config_file $CONFIG \
    $FILE_NAME \
    --outputDir "$OUTPUT_DIR" \
    --modelAndTokenizerName "$MODEL_AND_TOKENIZER_NAME" \
    --repoName "$REPO_NAME" \
    --finetune "$FINETUNE" \
    --existsBranch "$BRANCH_NAME" \
    --mixed_precision "$PRECISION" \
    --checkpointing_steps "$CHECKPOINT_ITERATIONS" \
    --logging_dir "$LOGS" \
    --train_batch_size $TRAIN_BATCH_SIZE \
    --val_batch_size $VAL_BATCH_SIZE \
    --max_input_length $MAX_INPUT_LENGTH \
    --max_target_length $MAX_TARGET_LENGTH \
    --gradient_accumulation_steps $ACCUMULATION_STEPS \
    --learning_rate $LR \
    --weight_decay $WEIGHT_DECAY \
    --train_epochs $EPOCHS \
    --with_tracking \
    2>&1 | tee $LOGS/output.log
    default.yml
    compute_environment: LOCAL_MACHINE
    debug: false
    distributed_type: FSDP
    downcast_bf16: 'no'
    fsdp_config:
      fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
      fsdp_backward_prefetch: NO_PREFETCH
      fsdp_cpu_ram_efficient_loading: true
      fsdp_forward_prefetch: false
      fsdp_offload_params: true # whether to offload parameters and gradients to CPU. Offloading of the state dict and of the optimizer state dict is configured on the FSDPPlugin via FullStateDictConfig and FullOptimStateDictConfig, respectively.
      fsdp_sharding_strategy: FULL_SHARD
      fsdp_state_dict_type: SHARDED_STATE_DICT
      fsdp_sync_module_states: true
      fsdp_transformer_layer_cls_to_wrap: T5Block
      fsdp_use_orig_params: false
    machine_rank: 0
    main_training_function: main
    mixed_precision: 'no'
    num_machines: 1
    num_processes: 1
    rdzv_backend: static
    same_network: true
    tpu_env: []
    tpu_use_cluster: false
    tpu_use_sudo: false
    use_cpu: false
    train.py
    import argparse
    import gc
    import threading
    import psutil
    import numpy as np
    import pandas
    import evaluate
    import sys
    import os
    import time
    import logging
    import math
    import json
    import nvidia_smi
    from pynvml import *
    from tqdm.auto import tqdm
    
    import torch
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
    from torch.optim import AdamW
    from torch.utils.data.distributed import DistributedSampler
    from torch.utils.data import DataLoader
    
    from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs, FullyShardedDataParallelPlugin
    from accelerate.logging import get_logger
    from accelerate.tracking import TensorBoardTracker
    from accelerate.utils import is_npu_available, is_xpu_available, GradientAccumulationPlugin
    
    import transformers
    from transformers import AutoTokenizer, AutoConfig
    from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
    from transformers import get_scheduler
    from transformers import set_seed
    from transformers import Adafactor
    from transformers.utils import send_example_telemetry
    from huggingface_hub import Repository, create_repo, get_full_repo_name
    
    import datasets
    from datasets import load_dataset
    
    import nltk
    from nltk.tokenize import sent_tokenize
    
    from peft import LoraConfig, TaskType
    from peft import get_peft_model
    import accelerate.optimizer
    from torch.optim.adamw import AdamW
    
    def monitor_gpuTemp(handle0, handle1):
        temp0 = nvidia_smi.nvmlDeviceGetTemperature(handle0, nvidia_smi.NVML_TEMPERATURE_GPU)
        temp1 = nvidia_smi.nvmlDeviceGetTemperature(handle1, nvidia_smi.NVML_TEMPERATURE_GPU)
        if temp0 >= 75 or temp1 >= 75: 
            print(f"\nGPU 0 Temperature: {temp0}C")
            print(f"\nGPU 1 Temperature: {temp1}C")
            raise RuntimeError("GPU Temperature is too high!")
        
    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]
    
        # ROUGE expects a newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
    
        return preds, labels
    
    
    def training_function(config, args):
        # For GPU temp monitoring
        nvidia_smi.nvmlInit()
        handle0 = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
        handle1 = nvidia_smi.nvmlDeviceGetHandleByIndex(1)
    
        dateAndTime = os.environ['CURRENT_DATE_TIME']
        nltk.download("punkt")
    
        # Output directory
        output_dir = args.outputDir
    
        # Pass the advanced FSDP settings not part of the accelerate config by creating fsdp_plugin
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=True), # can set these to true to provide more GPU memory at the cost of computation time
            optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=True), # can set these to true to provide more GPU memory at the cost of computation time
        )
            #activation_checkpointing=True, # provides more GPU memory at the cost of computation time
        #)
    
        # Initialize accelerator and gradient accumulation plugin
        grad_accum_plugin = GradientAccumulationPlugin(num_steps=args.gradient_accumulation_steps, sync_with_dataloader=False)
        if args.with_tracking:
            accelerator = Accelerator(
                cpu=False,
                mixed_precision=args.mixed_precision,
                log_with="tensorboard",
                project_dir=args.logging_dir,
                gradient_accumulation_plugin=grad_accum_plugin,
                fsdp_plugin=fsdp_plugin
            )
        else:
            accelerator = Accelerator(gradient_accumulation_plugin=grad_accum_plugin, fsdp_plugin=fsdp_plugin)
        accelerator.print(accelerator.distributed_type)
    
        @accelerator.on_main_process
        def training_log(epoch, num_epoch, i_iter, epoch_iters, optimizer, loss):
            msg = '\nEpoch: [{}/{}] Iter:[{}/{}], lr: {}, Loss: {:.6f}'.format(
                epoch, num_epoch, i_iter, epoch_iters,
                [x['lr'] for x in optimizer.param_groups], loss)
            print(msg)
    
        @accelerator.on_main_process
        def print_rouge(epoch, result):
            print(f"Epoch {epoch}:", result)
    
    
        if hasattr(args.checkpointing_steps, "isdigit"):
            if args.checkpointing_steps == "epoch":
                checkpointing_steps = args.checkpointing_steps
            elif args.checkpointing_steps.isdigit():
                checkpointing_steps = int(args.checkpointing_steps)
            else:
                raise ValueError(
                    f"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed."
                )
        else:
            checkpointing_steps = None
        # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
        lr = config["lr"]
        num_epochs = int(args.train_epochs)
        seed = int(config["seed"])
        batch_size = int(config["batch_size"])
    
        # We need to initialize the trackers we use, and also store our configuration
        if args.with_tracking:
            experiment_config = vars(args)
            accelerator.init_trackers("fsdp_pubMed_no_trainer", experiment_config)
    
        tokenizer = AutoTokenizer.from_pretrained(args.modelAndTokenizerName)
        raw_train_dataset = load_dataset("ccdv/pubmed-summarization", "document", split="train[:1%]")
        raw_val_dataset = load_dataset("ccdv/pubmed-summarization", "document", split="validation[:1%]")
        column_names = raw_train_dataset.column_names
        metric = evaluate.load("rouge")
    
        # Define tokenizer pre-processing function for the dataset
        max_input_length = args.max_input_length # this defines the maximum number of tokens the model can take as input for any given task.
        max_target_length = args.max_target_length
        padding = "max_length"
        truncation = "longest_first"
        def tokenize_function(examples):
    
            model_inputs = tokenizer(examples["article"], max_length=max_input_length, padding=padding, truncation=True)
            labels = tokenizer(text_target=examples["abstract"], max_length=max_target_length, padding=padding, truncation=True)
    
            # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
            # padding in the loss.
            if padding == "max_length" and args.ignore_pad_token_for_loss:
                labels["input_ids"] = [
                    [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
                ]
    
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
    
        # Apply the method we just defined to all the examples in all the splits of the dataset
        # starting with the main process first:
        accelerator.wait_for_everyone()
        with accelerator.main_process_first():
            train_dataset = raw_train_dataset.map(
                tokenize_function, 
                batched=True,
                num_proc=6,
                remove_columns=column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on raw train split"
                )
            val_dataset = raw_val_dataset.map(
                tokenize_function,
                batched=True,
                num_proc=6,
                remove_columns=column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on raw val split"
                )
            train_dataset.set_format("torch")
            train_dataset = train_dataset.select(range(10))
            val_dataset.set_format("torch")
            val_dataset = val_dataset.select(range(10))
    
        # If the batch size is too big we use gradient accumulation
        gradient_accumulation_steps = args.gradient_accumulation_steps
        train_batch_size = args.train_batch_size
        val_batch_size = args.val_batch_size
    
        set_seed(seed)
    
        # Instantiate the model (we build the model here so that the seed also control new weights initialization)
        autoConfig = AutoConfig.from_pretrained(args.modelAndTokenizerName)
        model = AutoModelForSeq2SeqLM.from_pretrained(args.modelAndTokenizerName, config=autoConfig)
        #model.gradient_checkpointing_enable() # reduces memory usage during training
    
        label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer, 
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if accelerator.mixed_precision == "fp16" else None,
            # pad_to_multiple_of=8 if accelerator.use_fp16" else None,
        )
        # print("mixed precision:",accelerator.mixed_precision)
        # Instantiate dataloaders.
        train_dataloader = DataLoader(
            train_dataset,
            pin_memory=True, # for speeding up training set = true, enables faster transfers between CPU and GPU memory
            shuffle=True,
            collate_fn=data_collator,
            batch_size=train_batch_size,
            num_workers=4 # for speeding up training, spawns several workers to preload the data faster. If GPU utilization is far from 100%, increase number of workers.
        )
    
        val_dataloader = DataLoader(
            val_dataset,
            pin_memory=True,
            collate_fn=data_collator,
            batch_size=val_batch_size,
            num_workers=4
        )
    
        # model= accelerator.prepare(
        #     model
        # )
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.003,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
    
        #optimizer = Adafactor(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    
        # Training loop updates
        num_train_epochs = args.train_epochs
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
        max_training_steps = num_train_epochs * num_update_steps_per_epoch
    
        # Instantiate scheduler
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )
    
        # Validation loop updates
        num_val_epochs = args.train_epochs
        num_update_steps_per_epoch_val = math.ceil(len(val_dataloader) / args.gradient_accumulation_steps)
        max_validation_steps = num_val_epochs * num_update_steps_per_epoch_val
    
        # Prepare accelerator
        # optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        #     optimizer, train_dataloader, val_dataloader, lr_scheduler
        # )
        model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader, lr_scheduler
        )
        print("optimizer type:",type(optimizer))
        # Recalculate training loop updates because it changes after prepare method sometimes
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
        max_training_steps = num_train_epochs * num_update_steps_per_epoch
        num_train_epochs = math.ceil(max_training_steps / num_update_steps_per_epoch)
        progress_bar = tqdm(range(max_training_steps), disable=not accelerator.is_local_main_process)
    
        # Recalculate validation loop updates because it changes after prepare method
        num_update_steps_per_epoch_val = math.ceil(len(val_dataloader) / args.gradient_accumulation_steps)
        max_validation_steps = num_val_epochs * num_update_steps_per_epoch_val
        num_val_epochs = math.ceil(max_validation_steps / num_update_steps_per_epoch_val)
        val_progress_bar = tqdm(range(max_validation_steps), disable=not accelerator.is_local_main_process)
    
        completed_steps = 0
        val_completed_steps = 0
        progress_bar.update(completed_steps)
        val_progress_bar.update(val_completed_steps)
    
    
        # Potentially load in the weights and states from a previous save
        if args.resume_from_checkpoint:
            if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
                accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
                accelerator.load_state(args.resume_from_checkpoint)
                path = os.path.basename(args.resume_from_checkpoint)
            else:
                # Get the most recent checkpoint
                dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
                dirs.sort(key=os.path.getctime)
                path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
            # Extract `epoch_{i}` or `step_{i}`
            training_difference = os.path.splitext(path)[0]
    
            if "epoch" in training_difference:
                num_epochs -= int(training_difference.replace("epoch_", ""))
                resume_step = None
            else:
                resume_step = int(training_difference.replace("step_", ""))
                num_epochs -= resume_step // len(train_dataloader)
                # If resuming by step, we also need to know exactly how far into the DataLoader we went
                resume_step = (num_epochs * len(train_dataloader)) - resume_step
    
        # Now we train the model
        for epoch in range(num_train_epochs):
            monitor_gpuTemp(handle0, handle1)
            completed_steps=0
            progress_bar.update(completed_steps)
            # Training
            model.train()
            if args.with_tracking:
                total_loss = 0
            for step, batch in enumerate(train_dataloader):
                monitor_gpuTemp(handle0, handle1)
                with accelerator.accumulate(model):
                    # We could avoid this line since we set the accelerator with `device_placement=True`.
                    batch.to(accelerator.device)
                    outputs = model(**batch)
                    loss = outputs.loss
                    # We keep track of the loss at each epoch
                    if args.with_tracking:
                        total_loss += loss.detach().float()
                    accelerator.backward(loss)
                    # tracer.DISABLE_WRAPPER = False
                    optimizer.step()
                    lr_scheduler.step() 
                    optimizer.zero_grad()
                    # tracer.DISABLE_WRAPPER = True  
                # Check if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    completed_steps += 1
    
                # Save checkpoint on specific iteration when using StateDictType is SHARDED_STATE_DICT
                if isinstance(checkpointing_steps, int):
                    if completed_steps % checkpointing_steps == 0:
                        ckpt_step_dir = f"{dateAndTime}/accel_state_ckpt_step_{completed_steps}"
                        if output_dir is not None:
                            accel_state_dir = os.path.join(output_dir, ckpt_step_dir)
                        accelerator.save_state(accel_state_dir)
    
                accelerator.wait_for_everyone()
                training_log(epoch, num_train_epochs, completed_steps, max_training_steps, optimizer, loss)
                accelerator.wait_for_everyone()
    
                if completed_steps >= 50:# max_training_steps: # changed this for quicker results
                    break
                    
    
            val_completed_steps=0
            val_progress_bar.update(val_completed_steps)
            model.eval()
            for step, batch in enumerate(val_dataloader):
                monitor_gpuTemp(handle0, handle1)
                batch.to(accelerator.device)
                with torch.no_grad():
                    outputs = model(**batch)
                    loss = outputs.loss
                    predictions = outputs.logits.argmax(dim=-1)
                    accelerator.wait_for_everyone()
                    predictions, targets = accelerator.gather_for_metrics((predictions, batch["labels"]))
    
                    # Send to cpu for conversion to numpy
                    predictions = predictions.cpu().numpy()
                    targets = targets.cpu().numpy()
                    # Replace -100 in the references (targets) since we can't decode them
                    targets = np.where(targets != -100, targets, tokenizer.pad_token_id)
                    if isinstance(predictions, tuple):
                        predictions = predictions[0]
                    decoded_preds = tokenizer.batch_decode(
                        predictions, skip_special_tokens=True
                    )
                    decoded_targets = tokenizer.batch_decode(targets, skip_special_tokens=True)
    
                    decoded_preds, decoded_targets = postprocess_text(
                        decoded_preds, decoded_targets
                    )
    
                    metric.add_batch(predictions=decoded_preds, references=decoded_targets)
                    if accelerator.sync_gradients:
                        val_progress_bar.update(1)
                        val_completed_steps += 1 
                    training_log(epoch, num_epochs, val_completed_steps, max_validation_steps, optimizer, loss)
                    # Specify how many iterations contribute towards validation rouge scores metric
                    if val_completed_steps >= 25:# max_validation_steps: # changed this for quicker results
                        break
    
            # Compute metrics
            result = metric.compute(use_stemmer=True)
    
            # Extract the median ROUGE scores
            result = {k: round(v * 100, 4) for k, v in result.items()}
            accelerator.wait_for_everyone()
            print_rouge(epoch, result)
            accelerator.wait_for_everyone()
    
            if args.with_tracking:
                result["train_loss"] = total_loss.item() / len(train_dataloader)
                result["epoch"] = epoch
                result["step"] = completed_steps
                accelerator.log(result, step=completed_steps)   
                    
            # Save and upload
            if epoch < num_train_epochs:
                accelerator.wait_for_everyone()
                if accelerator.is_local_main_process:
                    print(f"Saving checkpoint for epoch {epoch+1}")
                checkpoint_dir = f"{output_dir}/{dateAndTime}/checkpoints/epoch-{epoch+1}"
                os.makedirs(checkpoint_dir, exist_ok=True)
                accelerator.wait_for_everyone() 
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(checkpoint_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
                # Also save to default directory
                unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
                if accelerator.is_main_process:
                    tokenizer.save_pretrained(checkpoint_dir)
                    tokenizer.save_pretrained(output_dir)
                    # Function that adds, commits
                    #repo.push_to_hub(
                    #commit_message=f"Model {args.modelAndTokenizerName} checkpoint from epoch {epoch}", blocking=True
                    #)
    
        if output_dir is not None:
            accelerator.wait_for_everyone() 
            if accelerator.state.fsdp_plugin is not None:
                accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(checkpoint_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save)
            # Also save to default directory
            unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
            if accelerator.is_main_process:
                tokenizer.save_pretrained(checkpoint_dir)
                tokenizer.save_pretrained(output_dir)
                # Function that adds, commits
                #repo.push_to_hub(
                #commit_message=f"Model {args.modelAndTokenizerName} checkpoint from epoch {epoch}", blocking=True
                #)
    
                all_results = {f"eval_{k}": v for k, v in result.items()}
                with open(os.path.join(output_dir, "all_results.json"), "w") as f:
                    json.dump(all_results, f)
    
        nvidia_smi.nvmlShutdown()
        
    
        if args.with_tracking:
            accelerator.end_training()
        myDir = f"{output_dir}/{dateAndTime}/checkpoints/"
        return myDir
    
    
    def main():
        parser = argparse.ArgumentParser(description="Simple example of training script.")
        parser.add_argument(
            "--mixed_precision",
            type=str,
            default=None,
            choices=["no", "fp16", "bf16", "fp8"],
            help="Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU.",
        )
        parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
        parser.add_argument(
            "--checkpointing_steps",
            type=str,
            default=None,
            help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
        )
        parser.add_argument(
            "--resume_from_checkpoint",
            type=str,
            default=None,
            help="If the training should continue from a checkpoint folder.",
        )
        parser.add_argument(
            "--with_tracking",
            action="store_true",
            help="Whether to load in all available experiment trackers from the environment and use them for logging.",
        )
        parser.add_argument(
            "--logging_dir",
            type=str,
            default="logs",
            help="Location on where to store experiment tracking logs`",
        )
        parser.add_argument(
            "--gradient_accumulation_steps", 
            type=int, 
            required=False,
        )
        parser.add_argument(
            "--train_batch_size", 
            type=int, 
            required=True,
        )
        parser.add_argument(
            "--val_batch_size", 
            type=int, 
            required=True
        )
        parser.add_argument(
            "--learning_rate", 
            type=float, 
            required=True
        )
        parser.add_argument(
            "--weight_decay", 
            type=float, 
            required=True
        )
        parser.add_argument(
            "--train_epochs",
            type=int,
            required=True
        )
        parser.add_argument(
            "--outputDir",
            type=str,
            default=".",
            help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
        )
        parser.add_argument(
            "--repoName", 
            type=str, 
            required=True
        )
        parser.add_argument(
            "--finetune", 
            type=str, 
            required=True
        )
        parser.add_argument(
            "--existsBranch",
            type=str, 
            required=False
        )
        parser.add_argument(
            "--newBranch", 
            type=str, 
            required=False
        )
        parser.add_argument(
            "--modelAndTokenizerName",
            type=str,
            help="Path to pretrained model or model identifier from huggingface.co/models.",
            required=True,
        )
        parser.add_argument(
            "--ignore_pad_token_for_loss",
            type=bool,
            default=True,
            help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
        )
        parser.add_argument(
            "--max_input_length",
            type=int,
            help="Maximum input sequence length for inference",
            required=True,
        )
        parser.add_argument(
            "--max_target_length",
            type=int,
            help="Maximum output target sequence length for prediction",
            required=True,
        )
    
        args = parser.parse_args()
        config = {"lr": args.learning_rate, "num_epochs": args.train_epochs, "seed": 42, "batch_size": args.train_batch_size}
        myDir = training_function(config, args)
        input_prompt = "Type a very long input prompt..."
        for folder in sorted(os.listdir(myDir)):
            ckpt = os.path.join(myDir, folder)
            hub_model_id = ckpt
            tokenizer = AutoTokenizer.from_pretrained(hub_model_id)
            input_ids = tokenizer.encode(input_prompt, return_tensors="pt")
            model = AutoModelForSeq2SeqLM.from_pretrained(hub_model_id)
    
            output = model.generate(
            input_ids,
            max_length=512,  # Generate up to 512 tokens
            min_length=128,      # Ensure the output is at least 128 tokens
            length_penalty=1.0, # No length penalty
            no_repeat_ngram_size=2, # Prevent repeating n-grams of size 2
            num_beams=2,
            early_stopping=True # Stop when the first beam hypothesis reaches EOS
        )
            print(f'{folder}',tokenizer.decode(output[0], skip_special_tokens=True))
        
        print(f"\nTraining Finished\n")
    
    
    
    if __name__ == "__main__":
        main()
  3. With fsdp_use_orig_params == false, the above scripts run 2 epochs with 10 steps per epoch.

  4. Notice that across epochs the validation scores are identical and the loss reported for the same batch does not change.

  5. Modify default.yml to set fsdp_use_orig_params: true and run step 2 again. Observe that the model now learns correctly, as reported in #2665 (comment), "FSDP Model not learning during training, loss stays constant".

Issue Root Cause

FSDP(model) flattens the model parameters and uses the flattened ones for training, as can be seen in the PyTorch codebase (e.g. https://github.com/pytorch/pytorch/blob/6a096a0b960b415e95b89efb6cc6eeaa9c0f48ab/torch/distributed/fsdp/_unshard_param_utils.py#L122).

In the buggy pipeline train.py, the AdamW optimizer is initialized before the model is wrapped with FSDP, which happens inside accelerator.prepare(model, optimizer, train_dataloader, val_dataloader, lr_scheduler). The optimizer therefore only holds references to the original, unflattened model parameters, and accelerator.prepare does not remap them when fsdp_use_orig_params == False.
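
To make this concrete, here is a minimal analogy in plain PyTorch (an illustration only; it does not use FSDP itself): an optimizer only updates the exact Parameter objects it was constructed with, so if a module's parameters are later replaced (which is what FSDP's flattening effectively does when use_orig_params is False), optimizer.step() never touches the tensors that are actually being trained.

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)  # captures the original Parameter objects

    # Simulate FSDP-style parameter replacement: the module now trains tensors the
    # optimizer has never seen.
    model.weight = nn.Parameter(model.weight.detach().clone())
    model.bias = nn.Parameter(model.bias.detach().clone())

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()        # gradients land on the *new* parameters
    optimizer.step()       # no-op: every parameter the optimizer holds has grad == None
    optimizer.zero_grad()  # also a no-op, for the same reason

    model_params = {id(p) for p in model.parameters()}
    opt_params = {id(p) for g in optimizer.param_groups for p in g["params"]}
    print(model_params & opt_params)  # empty set: model and optimizer state are inconsistent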

Additional Evidence

We ran our bug-detection tool against the problematic pipeline's runtime trace and observed two things (a minimal check along these lines is sketched after this list):

  1. The optimizer's step API was not updating the model: no model parameter changed and no computation op was invoked.
  2. The optimizer's zero_grad API was not doing anything: all gradients were already None before zero_grad was entered.
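
A minimal check along these lines (a hypothetical helper, not our detection tool) can be called once after accelerator.prepare(...) and the first accelerator.backward(loss) to surface both symptoms:

    def check_model_optimizer_sync(model, optimizer):
        """Report whether the optimizer still points at the (wrapped) model's parameters
        and whether any of its parameters ever receive a gradient."""
        model_param_ids = {id(p) for p in model.parameters()}
        opt_params = [p for group in optimizer.param_groups for p in group["params"]]

        shared = sum(id(p) in model_param_ids for p in opt_params)
        with_grad = sum(p.grad is not None for p in opt_params)

        print(f"optimizer params that belong to the model: {shared}/{len(opt_params)}")
        print(f"optimizer params holding a gradient:       {with_grad}/{len(opt_params)}")
        # In the buggy run both counts are 0, so step() and zero_grad() are effectively no-ops.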

Possible User-side Workaround

  1. Set fsdp_use_orig_params to True

or

  1. Initialize the optimizer after the model is wrapped, i.e. change the following (a fuller sketch follows the two snippets below):

    # buggy
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, val_dataloader, lr_scheduler
        )

    to

    # correct
    model = accelerator.prepare(model)
    
    # initialize optimizer using the updated params
    optimizer_grouped_parameters = xxx
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    
    optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, val_dataloader, lr_scheduler)
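
    A fuller sketch of this second workaround is below. It assumes the same accelerator, args, get_scheduler, and dataloaders as in train.py above, and is only an illustration, not the only valid ordering. Note that with fsdp_use_orig_params: false the wrapped model exposes flattened parameters under different names, so the name-based no_decay grouping from the original script may no longer match anything; the sketch therefore simply passes model.parameters().

    # Illustrative sketch: wrap the model first so the optimizer references the
    # parameters FSDP actually trains.
    model = accelerator.prepare(model)

    # Build the optimizer from the *prepared* model's parameters.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
    )

    # The lr_scheduler must also be (re)created from this new optimizer before preparing it.
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=max_training_steps,
    )

    optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        optimizer, train_dataloader, val_dataloader, lr_scheduler
    )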

Suggested accelerate-side Fix:

  1. Emit a warning when the optimizer's parameters and the model's parameters are observed not to overlap (a minimal check is sketched below).
  2. Update the optimizer automatically if fsdp_use_orig_params is observed to be True.
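
For the first suggestion, a minimal sketch of what such a check could look like (a hypothetical helper, not accelerate's actual code or API):

    import warnings

    def warn_if_optimizer_params_detached(model, optimizer):
        """Warn when none of the optimizer's parameters belong to the prepared model."""
        model_params = {id(p) for p in model.parameters()}
        opt_params = {id(p) for group in optimizer.param_groups for p in group["params"]}
        if model_params and opt_params and not (model_params & opt_params):
            warnings.warn(
                "None of the optimizer's parameters are parameters of the prepared model. "
                "With fsdp_use_orig_params=False, create the optimizer after the model has "
                "been wrapped (or enable fsdp_use_orig_params).",
                UserWarning,
            )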

We would be more than happy to provide a PR for this issue! Let us know how you want to proceed.

BenjaminBossan (Member) commented Nov 25, 2024

Thanks for the detailed report. I think it could also be the same issue as #3209 (and check #3213 for a proposed solution).

traincheck-team (Author):

Yes, it is the same issue! The proposed solution is reasonable. Thanks.

BenjaminBossan (Member):

Can you check whether huggingface/transformers#35212 has solved the issue? If not, could you also try whether switching off flash attention helps?

traincheck-team (Author) commented Dec 21, 2024

@BenjaminBossan Thanks! I tried huggingface/transformers#35212 by switching to this commit:

commit 7237b3ecfc65c0dbf62a330e47cd8deebc27428c (HEAD)
Author: Zach Mueller <[email protected]>
Date:   Fri Dec 13 13:20:51 2024 -0500

    Fix FSDP no longer working (#35212)
    
    Fix FSDP failing

The problem persists, and I do not have flash-attn installed in my environment.

traincheck-team (Author):

I think huggingface/transformers#35212 helps when users do not supply a custom optimizer and instead rely on the Trainer to create it.

If the user initializes the optimizer themselves, the Trainer simply respects whatever the user has done (see https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L1186-L1201).

BenjaminBossan (Member):

Thanks for the feedback @traincheck-team. Your link appears to be out of date (tip: use permalinks), but I think it is clear what you mean. When the user creates the optimizer, it is reasonable to honor that and not re-initialize it; in this case, however, that clashes with the need for delayed initialization. Honestly, I don't have a good idea of how to consolidate the two needs. Maybe @muellerzr has an idea?

github-actions (bot):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

The github-actions bot closed this issue as completed on Feb 7, 2025.