
Update Zephyr configs to account for UltraFeedback & TRL fixes #88

Merged
lewtun merged 29 commits into main from zephyr-repro on Jan 10, 2024

Conversation

lewtun (Member) commented on Jan 4, 2024:

Since the release of zephyr-7b-beta, there have been several important developments in the code and data used to train this model:

  • UltraFeedback was fixed to correct a few thousand incorrect labels, and the community pointed out that it is better to filter out the TruthfulQA subset to avoid contamination with the Open LLM Leaderboard
  • The learning rate scheduler in trl was not working correctly with packing, which affects the way we train the initial SFT model for downstream optimisation
  • Gradient checkpointing was not available for DPO when we trained zephyr-7b-beta, but is now available via the use_reentrant:True arg (see the sketch after this list)
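
For illustration, here is a minimal sketch of how that gradient checkpointing setting can be passed through transformers (the use_reentrant value mirrors the description above; the output path is a placeholder):

```python
from transformers import TrainingArguments

# Sketch: enable gradient checkpointing for the DPO step, setting the
# reentrant behaviour explicitly (gradient_checkpointing_kwargs requires
# transformers >= 4.35).
training_args = TrainingArguments(
    output_dir="data/zephyr-7b-dpo",  # placeholder path
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
)
```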

Given these changes, we've decided to do a full re-run of the Zephyr recipe to reconstruct a new set of hyperparameters that "just work" for full training and QLoRA.

The most notable changes include (a consolidated sketch follows this list):

  • Promoting QLoRA as the main alternative to full training. One issue with LoRA + ZeRO-3 is that it's not possible to load adapters on sharded models (e.g. you can't easily load the SFT and DPO adapters first and then shard the resulting model). Given that Zephyr is just a 7B model, QLoRA works great with DDP and is the simpler alternative to promote.
  • beta=0.01 gave better perf than beta=0.1
  • The global batch size was tuned for best perf (it turns out smaller batch sizes tend to work better for QLoRA). tl;dr we use GBS=128 for SFT/DPO with full-training and GBS=64/32 for SFT/DPO with QLoRA
  • The lora_r and lora_alpha hparams were tuned for best perf - it turns out lora_r=lora_alpha=16 is good for DPO irrespective of what values were used for SFT
  • Reducing the number of DPO epochs to 1 gave similar perf to the original Zephyr model, while being more compute efficient
  • Using AdamW and a cosine scheduler gave better perf in DPO
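
As a rough consolidation of the settings above, the QLoRA DPO run might be wired up like this (a sketch, not the recipe's exact config; the batch-size arithmetic assumes 8 GPUs, and the output path and LoRA dropout are illustrative):

```python
from peft import LoraConfig
from transformers import TrainingArguments

# lora_r = lora_alpha = 16 for DPO, irrespective of the SFT values;
# dropout and task_type are illustrative defaults.
peft_config = LoraConfig(r=16, lora_alpha=16, lora_dropout=0.05, task_type="CAUSAL_LM")

training_args = TrainingArguments(
    output_dir="data/zephyr-7b-dpo-qlora",  # placeholder
    num_train_epochs=1,                     # 1 DPO epoch
    per_device_train_batch_size=4,          # 4 per device x 8 GPUs x 1 step = GBS 32
    gradient_accumulation_steps=1,
    optim="adamw_torch",                    # AdamW ...
    lr_scheduler_type="cosine",             # ... with a cosine scheduler
)
```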

MT Bench Scores

There's some variability in MT-Bench, so treat these scores with a +/- 0.1 uncertainty:

| Model | MT-Bench Score |
| --- | --- |
| alignment-handbook/zephyr-7b-sft-full | 6.350 |
| alignment-handbook/zephyr-7b-dpo-full | 7.403 |
| alignment-handbook/zephyr-7b-sft-qlora | 6.484 |
| alignment-handbook/zephyr-7b-dpo-qlora | 7.544 |

Codebase changes

  • The formatting of dialogues for DPO has been extended to multi-turn contexts
  • Added a helper function to enable intermediate checkpoint loading for failed runs
  • Added the DPO loss_type to the config (see the sketch after this list)
  • Fixed upload hanging with DeepSpeed when pushing checkpoints to the Hub
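
Since the loss_type question comes up often, here is a minimal sketch of where that config value ends up: it is passed through to trl's DPOTrainer. The wrapper function and its arguments are illustrative, not the handbook's exact code:

```python
from trl import DPOTrainer

def build_dpo_trainer(model, tokenizer, train_dataset, training_args, peft_config):
    # Sketch: the loss_type key added to the config maps to DPOTrainer's
    # loss_type argument ("sigmoid" is TRL's default DPO loss; at the time
    # of this PR, "hinge", "ipo", and "kto_pair" were the other options).
    return DPOTrainer(
        model,
        ref_model=None,            # with peft_config set, TRL uses an implicit reference model
        args=training_args,
        beta=0.01,                 # beta=0.01 gave better perf than beta=0.1 (see above)
        loss_type="sigmoid",
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
    )
```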

TODO

Closes #87 #85 #68 #61 #45 #72 #44 #24 #59

@lewtun lewtun requested a review from edbeeching January 4, 2024 22:24
@@ -4,34 +4,35 @@ model_name_or_path: alignment-handbook/zephyr-7b-sft-full
  # Data training arguments
  # For definitions, see: src/h4/training/config.py
  dataset_mixer:
-   HuggingFaceH4/ultrafeedback_binarized: 1.0
+   HuggingFaceH4/ultrafeedback_binarized_fixed: 1.0
lewtun (Member, Author) commented:
Replace with the original source once we fix the dataset:

Suggested change:
-   HuggingFaceH4/ultrafeedback_binarized_fixed: 1.0
+   HuggingFaceH4/ultrafeedback_binarized: 1.0


  # Data training arguments
  dataset_mixer:
-   HuggingFaceH4/ultrafeedback_binarized: 1.0
+   HuggingFaceH4/ultrafeedback_binarized_fixed: 1.0
lewtun (Member, Author) commented:

Suggested change:
-   HuggingFaceH4/ultrafeedback_binarized_fixed: 1.0
+   HuggingFaceH4/ultrafeedback_binarized: 1.0

- torch_dtype: auto
  use_flash_attention_2: true
  model_revision: main
+ torch_dtype: float16
nathan-az (Contributor) commented:

Is it correct that training was done with float16 for the QLoRA training but bfloat16 for full-parameter training? (And if so, any reason for this?)

lewtun (Member, Author) replied:
Yes, that's correct. The main reason for these dtypes is that with 4-bit quantization the other modules will be cast to float16 by default, and I prefer to be explicit about this, while bfloat16 is needed for compatibility with FlashAttention2.

nathan-az (Contributor) commented on Jan 9, 2024:

Thanks @lewtun!

> main reason for these dtypes is that with 4-bit quantization, the other modules will be cast to float16 by default

I assume you are referring to bnb_4bit_compute_dtype being set to bfloat16 in get_quantization_config. Is there merit in making this configurable?

Since the Mistral 7B base is bfloat16 by default, would having a consistent type by also setting compute_dtype to bfloat16 (and torch_dtype to bfloat16) have any benefit?

I'm no expert here (on the memory representation or on how BNB/PEFT work) - my assumption is just that since the two dtypes technically have different dynamic ranges, there may be some benefit to keeping the same compute dtype across training stages (the pretrained base, SFT, and DPO).

lewtun (Member, Author) replied:

> I assume you are referring to bnb_4bit_compute_dtype being set to bfloat16 in get_quantization_config. Is there merit in making this configurable?

Ah yes, I'm referring to this line, and I think it would be good to set this as whatever the torch_dtype is in model_args (with float16 the default).

> Since the Mistral 7B base is bfloat16 by default, would having a consistent type by also setting compute_dtype to bfloat16 (and torch_dtype to bfloat16) have any benefit?

We haven't tested the effect of bfloat16 vs float16 with QLoRA, so once this PR is merged I can run a few experiments to test :)
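
A minimal sketch of what is being proposed here, assuming a get_quantization_config helper that reads a model_args object (the helper's exact shape in the handbook may differ):

```python
import torch
from transformers import BitsAndBytesConfig

def get_quantization_config(model_args):
    """Sketch: derive the 4-bit compute dtype from model_args.torch_dtype,
    falling back to float16 when torch_dtype is unset or "auto"."""
    if not model_args.load_in_4bit:
        return None
    compute_dtype = torch.float16  # the proposed default
    if model_args.torch_dtype not in (None, "auto"):
        compute_dtype = getattr(torch, model_args.torch_dtype)
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",  # illustrative; quant type is not discussed here
    )
```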

  )
- example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
- example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
+ example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
A commenter asked:
Hi, I noticed some inconsistencies here. The _strip_prefix function and the add_generation_prompt=True option seem to be missing. Is this intended, and if so, why?

lewtun (Member, Author) replied:

Yes, this function was refactored to support multi-turn preference datasets, and in the process I realised we could simplify the logic considerably by extracting the prompt and chosen/rejected responses directly from the list of dicts instead of formatting the string with the chat template.
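
A minimal sketch of that idea, not the handbook's exact implementation (field names follow the ultrafeedback_binarized convention, where "chosen" and "rejected" are full conversations as lists of {"role", "content"} dicts sharing the same prompt turns):

```python
def format_preference_example(example, tokenizer):
    # All turns before the final assistant message form the (multi-turn) prompt;
    # the final assistant message is the chosen / rejected response.
    prompt_messages = example["chosen"][:-1]
    chosen_messages = example["chosen"][-1:]
    rejected_messages = example["rejected"][-1:]
    example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
    example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
    example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
    return example
```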

Another commenter asked:

Do we have any example of using multi-turn data?

@lewtun lewtun marked this pull request as ready for review January 8, 2024 06:53
@lewtun lewtun changed the title from "[WIP] Update Zephyr configs to account for UltraFeedback & TRL fixes" to "Update Zephyr configs to account for UltraFeedback & TRL fixes" on Jan 8, 2024
edbeeching (Contributor) left a review comment:
Thanks @lewtun, no comments from my side. LGTM

@lewtun lewtun merged commit f0ffa0d into main Jan 10, 2024
3 checks passed
@lewtun lewtun deleted the zephyr-repro branch January 10, 2024 06:42