From 854cfc771119ce94d3cb74d4e3d6b53b164bf802 Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Mon, 23 Dec 2024 13:55:02 -0600
Subject: [PATCH 1/5] chore: Fix printing in `ALCF/helpers.sh`

---
 ALCF/helpers.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ALCF/helpers.sh b/ALCF/helpers.sh
index a52fc2bbb1..64ab51cc0a 100644
--- a/ALCF/helpers.sh
+++ b/ALCF/helpers.sh
@@ -922,8 +922,8 @@ buildDSconfig() {
     # export CPU_OPTIMIZER="${CPU_OPTIMIZER:-0}"
     export DS_CONFIG="${WORKING_DIR}/ds-configs/ds_stage${ZERO_STAGE}_mb${MICRO_BATCH}_gb${GLOBAL_BATCH}_pp${PP}_${DTYPE}.json"
     mkdir -p "$(dirname "${DS_CONFIG}")"
-    echo "DS_CONFIG: ${DS_CONFIG}"
-    printf "ZS: %s, MB: %s, GB: %s, PP: %s, DTYPE: %s" "${ZERO_STAGE}" "${MICRO_BATCH}" "${GLOBAL_BATCH}" "${PP}" "${DTYPE}"
+    printf "DS_CONFIG: %s\n" "${DS_CONFIG}"
+    printf "ZS=%s, MB=%s, GB=%s, PP=%s, DTYPE=%s\n" "${ZERO_STAGE}" "${MICRO_BATCH}" "${GLOBAL_BATCH}" "${PP}" "${DTYPE}"
     generateDSconfig "${DS_CONFIG}"
     cat "${DS_CONFIG}" | jq .
 }

From 7a6a5bd6b14da005d39dbe348906b45c27656281 Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Mon, 23 Dec 2024 13:55:45 -0600
Subject: [PATCH 2/5] chore: Alphabetize optimizer(s) in `megatron/arguments.py`

---
 megatron/arguments.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 9b0e6ccb1a..cef72acf51 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -954,20 +954,22 @@ def _add_training_args(parser):
         default='adam',
         choices=[
             'adam',
+            'adam8bit',
             'adamw',
-            'sophiag',
-            'sgd',
-            'ds.fusedlamb',
-            'ipex.lamb',
-            'ipex.fusedlamb',
+            'adamwschedulefree',
             'apex.adam',
             'apex.sgd',
-            'adamwschedulefree',
-            'sgdschedulefree',
+            'ds.fusedlamb',
+            'ds.onebitlamb',
             'galoreadamw',
-            'adam8bit',
             'galoreadamw8bit',
-            'galoreadamw8bitperlayer'
+            'galoreadamw8bitperlayer',
+            'ipex.fusedlamb',
+            'ipex.lamb',
+            'shampoo',
+            'sgd',
+            'sgdschedulefree',
+            'sophiag'
         ],
         help='Optimizer function'
     )

From 37194a6bb0139be9aa28ad8ee106f3f9ff46a167 Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Mon, 23 Dec 2024 13:56:09 -0600
Subject: [PATCH 3/5] chore: Update `megatron/optimizer/__init__.py`

---
 megatron/optimizer/__init__.py | 205 +++++++++++++++++++--------------
 1 file changed, 121 insertions(+), 84 deletions(-)

diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index 99145ff4f4..37107a44e7 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -127,6 +127,8 @@ def get_megatron_optimizer(
         param_groups
     )
 
+    optimizer = None
+    # ---- CPU Optimizer --------------------------------------
     if args.cpu_optimizer:
         assert args.optimizer == 'adam', 'CPU offloading is for Adam'
         if args.cpu_torch_adam:
@@ -141,52 +143,73 @@ def get_megatron_optimizer(
             betas=(args.adam_beta1, args.adam_beta2),
             eps=args.adam_eps,
         )
-
-    elif args.optimizer.lower() == "galore_adamw":
-        from galore_torch import GaLoreAdamW, GaLoreAdamW8bit
-        # redefine way to call galore_adamw
-        optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
-    elif args.optimizer.lower() == "galore_adamw":
-        # redefine way to call galore_adamw
-        optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
-    # implement adafactor
-    elif args.optimizer.lower() == "adafactor":
-        import transformers
-        args.beta1 = None if args.beta1 == 0.0 else args.beta1
-        optimizer = transformers.optimization.Adafactor(
+    # ---- Adam --------------------------------------
+    elif args.optimizer == 'adam':
+        if args.ds_fused_adam:
+            # global Adam
+            from deepspeed.ops.adam import FusedAdam
+            Adam = FusedAdam
+        else:
+            Adam = torch.optim.Adam
+        optimizer = Adam(
             param_groups,
             lr=args.lr,
-            eps=(1e-30, 1e-3),
-            clip_threshold=1.0,
-            decay_rate=-0.8,
-            beta1=args.beta1,
             weight_decay=args.weight_decay,
-            relative_step=False,
-            scale_parameter=False,
-            warmup_init=False,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_eps
         )
-    # low-rank adafactor
-    elif args.optimizer.lower() == "galore_adafactor":
-        args.beta1 = None if args.beta1 == 0.0 else args.beta1
-        optimizer = GaLoreAdafactor(
+    # ---- apex.Adam --------------------------------------------
+    elif str(args.optimizer).lower() == 'apex.adam':
+        assert get_accelerator().device_name() == 'cuda'
+        from apex.optimizers import FusedAdam as Adam
+        optimizer = Adam(
             param_groups,
             lr=args.lr,
-            eps=(1e-30, 1e-3),
-            clip_threshold=1.0,
-            decay_rate=-0.8,
-            beta1=args.beta1,
             weight_decay=args.weight_decay,
-            relative_step=False,
-            scale_parameter=False,
-            warmup_init=False,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_eps
         )
-    # 8-bit Adam
+    # ---- Adam8Bit --------------------------------------
     elif args.optimizer.lower() == "adam8bit":
         import bitsandbytes as bnb
         optimizer = bnb.optim.Adam8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)
+    # ---- AdamW --------------------------------------
+    elif str(args.optimizer).lower() == 'adamw':
+        optimizer = torch.optim.AdamW(
+            param_groups,
+            lr=args.lr,
+            weight_decay=args.weight_decay,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_eps
+        )
+    # ---- AdamW: ScheduleFree -------------------------------------
+    elif str(args.optimizer).lower() == 'adamwschedulefree':
+        import schedulefree
+        optimizer = schedulefree.AdamWScheduleFree(
+            param_groups,
+            lr=args.lr,
+            weight_decay=args.weight_decay,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_eps,
+            warmup_steps=args.lr_warmup_iters,
+            foreach=args.schedulefree_for_each,
+        )
+    # ---- AdamW: Galore ------------------------------------------
+    elif args.optimizer.lower() == "galore_adamw":
+        from galore_torch import GaLoreAdamW
+        # redefine way to call galore_adamw
+        optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
+    # elif args.optimizer.lower() == "galore_adamw":
+    #     from galore_torch import GaLoreAdamW
+    #     # redefine way to call galore_adamw
+    #     optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
+    # ---- AdamW: GaloRe 8Bit --------------------------------------
     elif args.optimizer.lower() == "galore_adamw8bit":
+        from galore_torch import GaLoreAdamW8bit
         optimizer = GaLoreAdamW8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)
+    # ---- AdamW8bitPerLayer: GaloRE ----------------------------
     elif args.optimizer.lower() == 'galore_adamw8bit_per_layer':
+        from galore_torch import GaLoreAdamW8bit
         # TODO: seems scheduler call twice in one update step, need to check, for now double the num_training_steps, warmup_steps and update_proj_gap
         optimizer_dict = {}
         for p in model.parameters():
@@ -219,45 +242,48 @@ def optimizer_hook(p):
             if p.requires_grad:
                 p.register_post_accumulate_grad_hook(optimizer_hook)
         layer_wise_flag = True
-    elif str(args.optimizer) == 'ipex.lamb':
-        from intel_extension_for_pytorch.optim._lamb import Lamb
-        optimizer = Lamb(
-            param_groups,
-            lr=args.lr,
-            weight_decay=args.weight_decay,
-            betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_eps,
-        )
-    elif str(args.optimizer) == 'ipex.fusedlamb':
-        from intel_extension_for_pytorch.optim._lamb import Lamb
-        optimizer = Lamb(
+    # ---- AdaFactor --------------------------------------
+    elif args.optimizer.lower() == "adafactor":
+        import transformers
+        args.beta1 = None if args.beta1 == 0.0 else args.beta1
+        optimizer = transformers.optimization.Adafactor(
             param_groups,
             lr=args.lr,
+            eps=(1e-30, 1e-3),
+            clip_threshold=1.0,
+            decay_rate=-0.8,
+            beta1=args.beta1,
             weight_decay=args.weight_decay,
-            betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_eps,
-            fused=True,
+            relative_step=False,
+            scale_parameter=False,
+            warmup_init=False,
         )
-    elif str(args.optimizer).lower() == 'ds.fusedlamb':
-        from deepspeed.ops.lamb import FusedLamb
-        optimizer = FusedLamb(
+    # ---- GaLore: Adafactor adafactor ------------------------------------
+    elif args.optimizer.lower() == "galore_adafactor":
+        from galore_torch import GaLoreAdafactor
+        args.beta1 = None if args.beta1 == 0.0 else args.beta1
+        optimizer = GaLoreAdafactor(
             param_groups,
             lr=args.lr,
+            eps=(1e-30, 1e-3),
+            clip_threshold=1.0,
+            decay_rate=-0.8,
+            beta1=args.beta1,
             weight_decay=args.weight_decay,
-            betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_eps,
+            relative_step=False,
+            scale_parameter=False,
+            warmup_init=False,
         )
-    elif str(args.optimizer).lower() == 'adamwschedulefree':
-        import schedulefree
-        optimizer = schedulefree.AdamWScheduleFree(
+    # ---- Apex: sgd ---------------------------------------------
+    elif str(args.optimizer).lower() == 'apex.sgd':
+        from apex.optimizers import FusedSGD as SGD
+        optimizer = SGD(
             param_groups,
             lr=args.lr,
             weight_decay=args.weight_decay,
-            betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_eps,
-            warmup_steps=args.lr_warmup_iters,
-            foreach=args.schedulefree_for_each,
+            momentum=args.sgd_momentum
         )
+    # ---- ScheduleFree: SGD -------------------------------
     elif str(args.optimizer).lower() == 'sgdschedulefree':
         import schedulefree
         optimizer = schedulefree.SGDScheduleFree(
@@ -268,45 +294,54 @@ def optimizer_hook(p):
             param_groups,
             lr=args.lr,
             weight_decay=args.weight_decay,
             warmup_steps=args.lr_warmup_iters,
             foreach=args.schedulefree_for_each,
         )
-    elif str(args.optimizer).lower() == 'apex.adam':
-        assert get_accelerator().device_name() == 'cuda'
-        from apex.optimizers import FusedAdam as Adam
-        optimizer = Adam(
+    # ---- Lamb: Ipex --------------------------------------------
+    elif str(args.optimizer) == 'ipex.lamb':
+        from intel_extension_for_pytorch.optim._lamb import Lamb
+        optimizer = Lamb(
             param_groups,
             lr=args.lr,
             weight_decay=args.weight_decay,
             betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_eps
-        )
-    elif str(args.optimizer).lower() == 'apex.sgd':
-        from apex.optimizers import FusedSGD as SGD
-        optimizer = SGD(
-            param_groups,
-            lr=args.lr,
-            weight_decay=args.weight_decay,
-            momentum=args.sgd_momentum
+            eps=args.adam_eps,
         )
-    elif str(args.optimizer).lower() == 'adamw':
-        optimizer = torch.optim.AdamW(
+    # ---- Lamb(Fused): Ipex ----------------------------------------
+    elif str(args.optimizer) == 'ipex.fusedlamb':
+        from intel_extension_for_pytorch.optim._lamb import Lamb
+        optimizer = Lamb(
             param_groups,
             lr=args.lr,
             weight_decay=args.weight_decay,
             betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_eps
+            eps=args.adam_eps,
+            fused=True,
         )
-    elif args.optimizer == 'adam':
-        if args.ds_fused_adam:
-            # global Adam
-            from deepspeed.ops.adam import FusedAdam
-            Adam = FusedAdam
-        else:
-            Adam = torch.optim.Adam
-        optimizer = Adam(
+    # ---- Lamb(Fused): DeepSpeed ------------------------------------------
+    elif str(args.optimizer).lower() == 'ds.fusedlamb':
+        from deepspeed.ops.lamb import FusedLamb
+        optimizer = FusedLamb(
             param_groups,
             lr=args.lr,
             weight_decay=args.weight_decay,
             betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_eps
+            eps=args.adam_eps,
+        )
+    # ---- Shampoo ----------------------------------------
+    elif args.optimizer == 'shampoo':
+        from distributed_shampoo.distributed_shampoo import DistributedShampoo
+        from distributed_shampoo.shampoo_types import AdamGraftingConfig
+        optimizer = DistributedShampoo(
+            model.parameters(),
+            lr=0.001,
+            betas=(0.9, 0.999),
+            epsilon=1e-12,
+            weight_decay=1e-05,
+            max_preconditioner_dim=8192,
+            precondition_frequency=100,
+            use_decoupled_weight_decay=True,
+            grafting_config=AdamGraftingConfig(
+                beta2=0.999,
+                epsilon=1e-08,
+            ),
         )
     elif args.optimizer == 'sgd':
         optimizer = torch.optim.SGD(
@@ -326,8 +361,10 @@ def optimizer_hook(p):
         )
     else:
         raise TypeError(f'{args.optimizer} optimizer is not supported.')
+    assert optimizer is not None
     if args.deepspeed:
         return optimizer
 
+    # Determine whether the params have main-grad field.
     params_have_main_grad = False
     if args.use_contiguous_buffers_in_local_ddp:

From 3413625eae82072e43cb4125fd2f8cd43282bda4 Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Mon, 23 Dec 2024 13:56:29 -0600
Subject: [PATCH 4/5] chore: Remove blankline in `pretrain_gpt_alcf.py`

---
 pretrain_gpt_alcf.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pretrain_gpt_alcf.py b/pretrain_gpt_alcf.py
index 3686c6ceeb..c196b19a42 100644
--- a/pretrain_gpt_alcf.py
+++ b/pretrain_gpt_alcf.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 """Pretrain GPT"""
-
 import time
 from typing import Callable
 from mpi4py import MPI

From 1504bd011fcda3728248eb017151d29d81f5668e Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Mon, 23 Dec 2024 13:57:02 -0600
Subject: [PATCH 5/5] feat: Use logging in `tools/hf2megads_weight_converter.py`

---
 tools/hf2megads_weight_converter.py | 84 +++++++++++++++++------------
 1 file changed, 50 insertions(+), 34 deletions(-)

diff --git a/tools/hf2megads_weight_converter.py b/tools/hf2megads_weight_converter.py
index 12468963c5..76d40ce252 100755
--- a/tools/hf2megads_weight_converter.py
+++ b/tools/hf2megads_weight_converter.py
@@ -2,15 +2,18 @@
 import re
 import sys
 import os
+import ezpz as ez
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import torch.distributed
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from megatron import print_rank_0, get_tokenizer, get_args
+from megatron import get_tokenizer, get_args
+from megatron.tokenizer.tokenizer import _vocab_size_with_padding
 from megatron.core import mpu
 from megatron.core import tensor_parallel
 from megatron.core.utils import divide
 from megatron.model import GPTModelPipe, Float16Module
-from megatron.utils import unwrap_model
+from megatron.utils import unwrap_model, get_logger
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.arguments import core_transformer_config_from_args
 from megatron.initialize import initialize_megatron
@@ -22,6 +25,8 @@
 import copy
 from pathlib import Path
 
+RANK = ez.setup_torch('deepspeed')
+log = get_logger(__name__, rank_zero_only=True)
 
 
 def add_extra_args(parser):
@@ -52,7 +57,7 @@ def compute_partition_range(hidden_size, local_rank, tp_size):
 
 def load_and_print_hf_weight(hf_ckpt_dir, hf_ckpt_num_of_shards):
     # Optimization point: We can selectively load specific 'shared' data to reduce CPU memory usage.
loaded = {} - print_rank_0( + log.info( f"----------------------------hf weight list----------------------------") for wid in range(1, hf_ckpt_num_of_shards + 1): @@ -60,7 +65,7 @@ def load_and_print_hf_weight(hf_ckpt_dir, hf_ckpt_num_of_shards): f"{hf_ckpt_dir}/pytorch_model-{wid:05d}-of-{hf_ckpt_num_of_shards:05d}.bin", map_location=torch.device('cpu')) for k in d: - print_rank_0(k) + log.info(k) assert k not in loaded loaded[k] = d[k].clone() del d @@ -71,7 +76,7 @@ def load_and_print_hf_weight_from_safetensor(hf_ckpt_dir, hf_ckpt_num_of_shards) from safetensors import safe_open # Optimization point: We can selectively load specific 'shared' data to reduce CPU memory usage. hf_model = {} - print_rank_0( + log.info( f"----------------------------hf weight list----------------------------") for wid in range(1, hf_ckpt_num_of_shards + 1): @@ -82,7 +87,7 @@ def load_and_print_hf_weight_from_safetensor(hf_ckpt_dir, hf_ckpt_num_of_shards) with safe_open(ckpt_path, framework="pt", device="cpu") as f: for k in f.keys(): - print_rank_0(f"name: {k}, shape: {f.get_tensor(k).shape}") + log.info(f"name: {k}, shape: {f.get_tensor(k).shape}") assert k not in hf_model hf_model[k] = f.get_tensor(k).clone() @@ -100,18 +105,18 @@ def load_and_print_hf_weight_auto(hf_ckpt_dir, no_init=True): else: hf_model = {} hf_auto_model = AutoModelForCausalLM.from_pretrained(hf_ckpt_dir, trust_remote_code=True, torch_dtype=torch.bfloat16) - print_rank_0( + log.info( f"----------------------------hf weight list----------------------------") for name, param in hf_auto_model.named_parameters(): hf_model[name] = param.clone() - print_rank_0(name) + log.info(name) return hf_model def print_distinct_weights(model): - print_rank_0( + log.info( f"----------------------------mega-ds weight list----------------------------") for pipe_rank in range(mpu.get_pipeline_model_parallel_world_size()): if mpu.get_pipeline_model_parallel_rank() == pipe_rank: @@ -154,7 +159,11 @@ def _embedding_refactor(self, pname, p): elif pname == f"{self.mega_emb_wnum}.word_embeddings.weight": hf_name = "model.embed_tokens.weight" hf_w = self.hf_model[hf_name] - assert hf_w.shape[0] == self.token_vocab + log.info(f"{hf_w.shape[0]=}") + log.info(f"{self.token_vocab=}") + if hf_w.shape[0] != self.padded_vocab_size: + torch.distributed.breakpoint(0) + assert hf_w.shape[0] == self.padded_vocab_size per_partition_vocab_size, start_index, end_index = compute_partition_range( self.padded_vocab_size, self.tp_rank, self.tp_size) end_index = min(end_index, self.token_vocab) @@ -170,9 +179,6 @@ def _embedding_refactor(self, pname, p): ) return new_w - - - def _direct_refactor(self, pname, p, hf_layer=None, subname=None): if pname == f"{self.mega_norm_wnum}.weight": hf_name = "model.norm.weight" @@ -203,16 +209,26 @@ def _qkv_refactor(self, pname, p, hf_layer): new_w = torch.zeros((per_partition_size * 3, wq.shape[1]), dtype=wq.dtype) + # >>> pname + # '2.self_attention.query_key_value.weight' + # >>> p.shap^U + # '2.self_attention.query_key_value.weight' + # >>> xp = p + # >>> xp.shape + # torch.Size([6144, 4096]) for i in range(num_attention_heads_per_partition): - current_index = start_index + i * hidden_size_per_attention_head - next_index = current_index + hidden_size_per_attention_head - new_w_index = i * (3 * hidden_size_per_attention_head) - new_w[new_w_index: new_w_index + (3 * hidden_size_per_attention_head), :] = \ - torch.cat([ - wq[current_index: next_index, :], - wk[current_index: next_index, :], - wv[current_index: next_index, :] - ], dim=0) + try: 
+ current_index = start_index + i * hidden_size_per_attention_head + next_index = current_index + hidden_size_per_attention_head + new_w_index = i * (3 * hidden_size_per_attention_head) + new_w[new_w_index: new_w_index + (3 * hidden_size_per_attention_head), :] = \ + torch.cat([ + wq[current_index: next_index, :], + wk[current_index: next_index, :], + wv[current_index: next_index, :] + ], dim=0) + except Exception: + torch.distributed.breakpoint(0) self.record_mapping_info( f"mega-ds:{pname,p.data.shape}<--hf{hf_wq_name,hf_wk_name,hf_wv_name,} cat q,k,v [{current_index}:{next_index},:] of q,k,v{wq.shape}" ) @@ -275,7 +291,7 @@ def _mlphto4h1_refactor(self, pname, p, hf_layer, subname): return new_w def transform_from_hf_to_megds(self): - assert self.is_refactored == False + assert not self.is_refactored new_w = None for pname, p in self.ds_model.named_parameters(): @@ -288,6 +304,7 @@ def transform_from_hf_to_megds(self): new_w = self._direct_refactor(pname, p) else: mobj = self.decoder_pat.match(pname) + assert mobj is not None layer_num = int(mobj.group(1)) subname = mobj.group(2) hf_layer = layer_num - self.offset_num @@ -316,7 +333,7 @@ def transform_from_hf_to_megds(self): new_w = None self.is_refactored = True - + def _embedding_refactor_to_hf(self, pname, ds_w): if pname == f"{self.mega_lm_head_wnum}.lm_head.weight": hf_w = self.hf_model.lm_head.weight @@ -327,7 +344,7 @@ def _embedding_refactor_to_hf(self, pname, ds_w): with torch.no_grad(): ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w) - + self.hf_dict[hf_w_name] = copy.deepcopy(ds_w_all_rank[:hf_w.shape[0], :]) def _direct_refactor_to_hf(self, pname, ds_w, hf_layer=None, subname=None): @@ -438,7 +455,7 @@ def record_mapping_info(self, record_msg): def inorder_show_record(self): assert self.is_refactored - print_rank_0( + log.info( f"----------------------------mapping list----------------------------") # print dp rank0 tp rank0 records. for pipe_rank in range(mpu.get_pipeline_model_parallel_world_size()): @@ -466,7 +483,7 @@ def load_hf_weights(args, no_init): def convert_ckpt(): """Build the model.""" args = get_args() - print_rank_0(f'building model ...') + log.info(f'building model ...') see_memory_usage(f"Before Building Model", force=True) config = core_transformer_config_from_args(args) @@ -491,18 +508,18 @@ def convert_ckpt(): # print_distinct_weights(hf_model) #init model and save - print_rank_0(f"before deepspeed init") + log.info(f"before deepspeed init") ds_engine, _, _, _ = deepspeed.initialize( model=ds_model, optimizer=None, args=args, lr_scheduler=None, mpu=mpu if args.no_pipeline_parallel else None) - print_rank_0(f"after deepspeed init") + log.info(f"after deepspeed init") if args.to_hf_ckpt: load_checkpoint([ds_engine], None, None, load_only_weights=True) - print_rank_0(f"completed to load deepspeed actual checkpoint") + log.info(f"completed to load deepspeed actual checkpoint") # refactor weight from hf to mega-ds and vice versa @@ -523,7 +540,7 @@ def convert_ckpt(): if torch.distributed.is_initialized(): torch.distributed.barrier() - print_rank_0(f"hf checkpoint will be saved in {save_path}/release ") + log.info(f"hf checkpoint will be saved in {save_path}/release ") if mpu.is_pipeline_last_stage(): ## doing checkpoint merging and saving... 
# hf_model.tie_weights() @@ -539,10 +556,9 @@ def convert_ckpt(): # mega-ds checkpoint will be saved in args.save hf_model.save_pretrained(os.path.join(save_path, "release"), safe_serialization=True) else: - print_rank_0(f"mega-ds checkpoint will be saved in {args.save}") + log.info(f"mega-ds checkpoint will be saved in {args.save}") save_checkpoint(0, [ds_engine], None, None) - - print_rank_0(f"save checkpoint completed") + log.info(f"save checkpoint completed") if __name__ == "__main__":
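
Editor's aside, not part of the patch series above: the new 'adamwschedulefree' and 'sgdschedulefree' branches added to megatron/optimizer/__init__.py wrap the schedulefree package, which expects the optimizer object itself to be switched between train and eval modes around training and validation. A minimal, self-contained sketch of that usage contract follows; the toy linear model and loop are illustrative stand-ins, not code from these patches.

    import torch
    import schedulefree

    model = torch.nn.Linear(16, 1)  # toy stand-in for the real model
    optimizer = schedulefree.AdamWScheduleFree(
        model.parameters(), lr=1e-3, warmup_steps=100
    )

    optimizer.train()  # schedule-free optimizers track a train/eval state of their own
    for _ in range(10):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    optimizer.eval()  # switch back before validation or checkpointing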