From 88f0bb1f70995ae991a5e4dff0e00e4d2e88d9f2 Mon Sep 17 00:00:00 2001 From: lyxok1 Date: Wed, 6 Dec 2023 00:01:41 +0800 Subject: [PATCH 1/2] update to save zero3 ckpt --- chatllms/utils/model_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chatllms/utils/model_utils.py b/chatllms/utils/model_utils.py index a053f5d..8f4ad47 100644 --- a/chatllms/utils/model_utils.py +++ b/chatllms/utils/model_utils.py @@ -11,6 +11,7 @@ from transformers.trainer_utils import get_last_checkpoint from transformers.generation.logits_process import LogitsProcessor from transformers.generation.utils import LogitsProcessorList +from transformers.deepspeed import is_deepspeed_zero3_enabled from chatllms.data.data_utils import (DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN) @@ -311,10 +312,14 @@ def get_logits_processor() -> LogitsProcessorList: def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str): """Collects the state dict and dump to disk.""" state_dict = trainer.model.state_dict() - if trainer.args.should_save: + if not is_deepspeed_zero3_enabled() and trainer.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + elif is_deepspeed_zero3_enabled(): + # save for deepspeed ZeRO3 checkpoint + if not trainer.wrapped_model.save_16bit_model(output_dir): + trainer.wrapped_model.save_checkpoint(output_dir, save_latest=True) \ No newline at end of file From bccb6adf77398df54e09b8418fafe00b6d983b8e Mon Sep 17 00:00:00 2001 From: lyxok1 Date: Sat, 9 Dec 2023 23:47:05 +0800 Subject: [PATCH 2/2] fix zero3 saving issue; compatible with transformers --- chatllms/utils/model_utils.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/chatllms/utils/model_utils.py b/chatllms/utils/model_utils.py index 8f4ad47..80853d0 100644 --- a/chatllms/utils/model_utils.py +++ 
b/chatllms/utils/model_utils.py @@ -309,17 +309,20 @@ def get_logits_processor() -> LogitsProcessorList: return logits_processor + def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str): """Collects the state dict and dump to disk.""" - state_dict = trainer.model.state_dict() - if not is_deepspeed_zero3_enabled() and trainer.args.should_save: - cpu_state_dict = { - key: value.cpu() - for key, value in state_dict.items() - } - del state_dict - trainer._save(output_dir, state_dict=cpu_state_dict) # noqa - elif is_deepspeed_zero3_enabled(): - # save for deepspeed ZeRO3 checkpoint - if not trainer.wrapped_model.save_16bit_model(output_dir): - trainer.wrapped_model.save_checkpoint(output_dir, save_latest=True) \ No newline at end of file + trainer.save_model(output_dir) + + # state_dict = trainer.model.state_dict() + # if not is_deepspeed_zero3_enabled() and trainer.args.should_save: + # cpu_state_dict = { + # key: value.cpu() + # for key, value in state_dict.items() + # } + # del state_dict + # trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + # elif is_deepspeed_zero3_enabled(): + # # save for deepspeed ZeRO3 checkpoint + # if not trainer.wrapped_model.save_16bit_model(output_dir): + # trainer.wrapped_model.save_checkpoint(output_dir, save_latest=True) \ No newline at end of file