Update Zephyr configs to account for UltraFeedback & TRL fixes #88
Conversation
```diff
@@ -4,34 +4,35 @@ model_name_or_path: alignment-handbook/zephyr-7b-sft-full
 # Data training arguments
 # For definitions, see: src/h4/training/config.py
 dataset_mixer:
-  HuggingFaceH4/ultrafeedback_binarized: 1.0
+  HuggingFaceH4/ultrafeedback_binarized_fixed: 1.0
```
Replace with the original source once we fix the dataset:

```suggestion
  HuggingFaceH4/ultrafeedback_binarized: 1.0
```
```diff
 # Data training arguments
 dataset_mixer:
-  HuggingFaceH4/ultrafeedback_binarized: 1.0
+  HuggingFaceH4/ultrafeedback_binarized_fixed: 1.0
```
```suggestion
  HuggingFaceH4/ultrafeedback_binarized: 1.0
```
```diff
-torch_dtype: auto
 use_flash_attention_2: true
 model_revision: main
+torch_dtype: float16
```
Is it correct that training was done with `float16` for the QLoRA training but `bfloat16` for full parameter training? (And any reason for this, if so?)
Yes that's correct - the main reason for these dtypes is that with 4-bit quantization, the other modules will be cast to `float16` by default and I prefer to be explicit about this, while `bfloat16` is needed for compatibility with FlashAttention2.
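For illustration, a minimal sketch of the two setups being discussed (the model name is taken from the diff above; the exact kwargs are assumptions rather than the handbook's actual loading code):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# QLoRA setup: 4-bit base weights; the non-quantized modules are kept in
# float16, matching the explicit `torch_dtype: float16` in the QLoRA config.
qlora_model = AutoModelForCausalLM.from_pretrained(
    "alignment-handbook/zephyr-7b-sft-full",
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        # per the discussion below, get_quantization_config sets this to bfloat16
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)

# Full-parameter setup: bfloat16 weights for FlashAttention-2 compatibility.
full_model = AutoModelForCausalLM.from_pretrained(
    "alignment-handbook/zephyr-7b-sft-full",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)
```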
Thanks @lewtun!

> main reason for these dtypes is that with 4-bit quantization, the other modules will be cast to float16 by default

I assume you are referring to `bnb_4bit_compute_dtype` being set to `bfloat16` in `get_quantization_config`. Is there merit in making this configurable?

Since the Mistral 7B base is `bfloat16` by default, would having a consistent type by also setting `compute_dtype` to `bfloat16` (and `torch_dtype` to `bfloat16`) have any benefit? I'm no expert here (on the memory representation or on how BNB/PEFT work) - my assumption is just that since the two formats technically have different dynamic ranges, there may be some benefit to keeping a consistent compute dtype across training stages (the pretrained base, SFT, and DPO).
> I assume you are referring to bnb_4bit_compute_dtype being set to bfloat16 in get_quantization_config. Is there merit in making this configurable?

Ah yes, I'm referring to this line, and I think it would be good to set this to whatever the `torch_dtype` is in `model_args` (with `float16` as the default).

> Since the mistral 7b base is bfloat16 by default, would having a consistent type by also setting compute_dtype to bfloat16 (and torch_dtype to bfloat16) have any benefit?

We haven't tested the effect of `bfloat16` vs `float16` with QLoRA, so once this PR is merged I can run a few experiments to test :)
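A rough sketch of what that change could look like (only `get_quantization_config` is named in the thread; the `model_args` fields, the string-valued `torch_dtype`, and the `nf4` quant type are assumptions):

```python
import torch
from transformers import BitsAndBytesConfig

def get_quantization_config(model_args):
    """Build a 4-bit config whose compute dtype follows `torch_dtype`."""
    if not model_args.load_in_4bit:
        return None
    # Fall back to float16 when torch_dtype is unset or "auto"
    # (assumes torch_dtype is stored as a string like "bfloat16").
    compute_dtype = torch.float16
    if model_args.torch_dtype not in (None, "auto"):
        compute_dtype = getattr(torch, model_args.torch_dtype)
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )
```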
```diff
 )
-example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
-example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
+example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
```
Hi, I noticed some inconsistencies here: the `_strip_prefix` function and the `add_generation_prompt=True` option are missing. Is this intended, and if so, why?
Yes, this function was refactored to support multi-turn preference datasets, and in the process I realised we could simplify the logic considerably by extracting the prompt & chosen/rejected responses directly from the list of dicts instead of formatting the string with the chat template.
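Roughly, the simplification means something like the following (a sketch assuming the `ultrafeedback_binarized`-style format where `chosen`/`rejected` are full message lists; the function name here is illustrative and the handbook's actual code may differ):

```python
def apply_dpo_template(example, tokenizer):
    # Each example carries full conversations as lists of
    # {"role": ..., "content": ...} dicts.
    prompt_messages = example["chosen"][:-1]    # shared (possibly multi-turn) context
    chosen_messages = example["chosen"][-1:]    # final assistant turn
    rejected_messages = example["rejected"][-1:]

    # Format each piece directly; no prefix stripping needed.
    example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
    example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
    example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
    return example
```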
Do we have any example of using multi-turn data?
Thanks @lewtun, no comments from my side. LGTM
Since the release of `zephyr-7b-beta` there have been several important developments to the code and data used to train this model:

- `trl` was not working correctly with packing, which affects the way we train the initial SFT model for downstream optimisation
- gradient checkpointing was not enabled at the time we trained `zephyr-7b-beta`, but now is via the `use_reentrant: True` arg

Given these changes, we've decided to do a full re-run of the Zephyr recipe to reconstruct a new set of hyperparameters that "just work" for full training and QLoRA.
The most notable changes include:

- `beta=0.01` gave better perf than `beta=0.1`
- the `lora_r` and `lora_alpha` hparams were tuned for best perf - it turns out `lora_r=lora_alpha=16` is good for DPO irrespective of what values were used for SFT (a sketch of these settings follows below)
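As a rough illustration of how these hyperparameters map onto code (newer `trl` releases expose `beta` and `loss_type` via `DPOConfig`; the handbook itself drives them from YAML configs, so treat this as a sketch rather than the recipe's actual code):

```python
from peft import LoraConfig
from trl import DPOConfig

# LoRA settings: lora_r = lora_alpha = 16 worked well for DPO,
# irrespective of the values used during SFT.
peft_config = LoraConfig(r=16, lora_alpha=16, task_type="CAUSAL_LM")

# DPO settings: beta=0.01 gave better perf than beta=0.1.
training_args = DPOConfig(
    output_dir="zephyr-7b-dpo",  # illustrative path
    beta=0.01,
    loss_type="sigmoid",  # the standard DPO loss; now exposed in the config
)
```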
MT Bench Scores

There's some variability in MT-Bench, so treat these scores with a +/- 0.1 uncertainty:
Codebase changes

- Added `loss_type` to config

TODO
- `ultrafeedback_binarized`: reload the fixed `ultrafeedback` dataset and filter out TruthfulQA samples (PR https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized/discussions/3)

Closes #87 #85 #68 #61 #45 #72 #44 #24 #59