Skip to content

Commit

Permalink
[LLM] support download ckpt from pdc (#9443)
Browse files Browse the repository at this point in the history
  • Loading branch information
SylarTiaNII authored Nov 21, 2024
1 parent b3d5e92 commit a8803d7
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 18 deletions.
20 changes: 10 additions & 10 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
SAFE_PEFT_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_INDEX_NAME,
)
from ..utils.fault_tolerance import LOSS_INF_ERROR, LOSS_NAN_ERROR
from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available
from ..utils.log import MetricsDumper, logger
from .argparser import strtobool
Expand Down Expand Up @@ -137,6 +138,7 @@
ShardingOption,
TrainerMemoryTracker,
TrainOutput,
download_recovery_ckpt_from_pdc,
find_batch_size,
get_last_checkpoint,
get_scheduler,
Expand Down Expand Up @@ -186,16 +188,6 @@
except:
from paddle.fluid.dataloader.dataloader_iter import _DataLoaderIterBase

try:
from paddle.framework.recall_error import LOSS_NAN_ERROR
except ImportError:
LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"

try:
from paddle.framework.recall_error import LOSS_INF_ERROR
except ImportError:
LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"


__all__ = ["Trainer"]

Expand Down Expand Up @@ -694,6 +686,14 @@ def train(
os.makedirs(resume_from_checkpoint, exist_ok=True)
logger.info(f"Reset resume_from_checkpoint to temp directory : {resume_from_checkpoint}")

if resume_from_checkpoint is not None and self.args.pdc_download_ckpt:
if self.is_local_process_zero():
download_recovery_ckpt_from_pdc(resume_from_checkpoint, self.args.pdc_download_timeout)
if self.args.world_size > 1:
logger.info("Wait all processes finish downloading...")
paddle.distributed.barrier()
logger.info("All processes finished downloading from pdc")

# memory metrics - must set up as early as possible
self._memory_tracker.start()
if not self.args.should_load_sharding_stage1_model:
Expand Down
34 changes: 34 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@

from ..trainer.argparser import strtobool
from ..transformers.tokenizer_utils_base import BatchEncoding
from ..utils.fault_tolerance import PDC_DOWNLOAD_ERROR
from ..utils.import_utils import is_paddle_cuda_available, is_psutil_available
from ..utils.log import logger
from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool

__all__ = [
"TrainOutput",
Expand Down Expand Up @@ -1073,3 +1075,35 @@ def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0):
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_ and local_seed not in tracker.seeds_:
tracker.add("local_seed", local_seed)


def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
    """Download checkpoint from PDC for resuming training after failover. Longjob environment is necessary.

    Args:
        recovery_checkpoint_path (`str`):
            local path to load checkpoint for training recovery; its last path
            component must match the ``checkpoint-<step>`` naming convention.
        timeout (`int`):
            max wait time for download

    Raises:
        RuntimeError: if the checkpoint path cannot be parsed, or the PDC
            download fails with any code other than Success/LocalPathExist.
    """
    try:
        # The checkpoint directory name encodes the global step (e.g. "checkpoint-1000");
        # PDC only needs the step number to locate the remote snapshot.
        base_dir, download_dir = os.path.split(os.path.normpath(recovery_checkpoint_path))
        if base_dir != "" and not os.path.exists(base_dir):
            os.makedirs(base_dir, exist_ok=True)
        # _re_checkpoint.search returns None for a non-conforming name; the
        # resulting AttributeError is caught below and surfaced as a parse error.
        download_step = int(_re_checkpoint.search(download_dir).groups()[0])
    except Exception as e:
        # Chain the original exception so the parsing failure stays debuggable.
        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}") from e
    start_time = time.time()
    result = pdc_tool.pdc_download_checkpoint(download_step, timeout)
    end_time = time.time()
    if result == PDCErrorCode.Success:
        logger.info(f"Successfully downloaded checkpoint from PDC, total time cost: {end_time - start_time} seconds.")
    elif result == PDCErrorCode.LocalPathExist:
        # A previous attempt (or another local process) already fetched the files; treat as success.
        logger.warning(
            f"Skipping download checkpoint since file exists at local, total time cost: {end_time - start_time} seconds."
        )
    else:
        # Use .get so an unmapped error code cannot raise a KeyError and mask the real failure.
        raise RuntimeError(
            f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download checkpoint from PDC, recovery_checkpoint_path: {recovery_checkpoint_path}, timeout: {timeout}; error details: {PDCErrorMessageMap.get(result, result)}"
        )
17 changes: 17 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import paddle.distributed as dist
from paddle.distributed import fleet

from ..utils.fault_tolerance import is_ft_env
from ..utils.log import logger
from .trainer_utils import (
IntervalStrategy,
Expand Down Expand Up @@ -856,6 +857,14 @@ class TrainingArguments:
save_sharding_stage1_model_include_freeze_params: Optional[bool] = field(
default=False, metadata={"help": "Save Sharding Stage1 Model Exclude Freeze Params"}
)
pdc_download_ckpt: Optional[bool] = field(
default=False,
metadata={"help": "Download checkpoint in paddlecloud longjob environment"},
)
pdc_download_timeout: Optional[int] = field(
default=300,
metadata={"help": "Timeout seconds for downloading checkpoint from remote cluster."},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
Expand Down Expand Up @@ -1658,6 +1667,14 @@ def is_segment_parallel_supported():
f"The local_ran: {self.local_rank} should be consistent with the world size: {paddle.distributed.get_world_size()}."
)

# process fault tolerance settings
if not is_ft_env():
if self.pdc_download_ckpt:
logger.warning(
"pdc_download_ckpt can only be set as true inside FT environment. Automatically disable it now."
)
self.pdc_download_ckpt = False

def __str__(self):
self_as_dict = asdict(self)
self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
Expand Down
36 changes: 36 additions & 0 deletions paddlenlp/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from tqdm.auto import tqdm

from .env import DOWNLOAD_SERVER, FAILED_STATUS, SUCCESS_STATUS
from .fault_tolerance import PDC_DOWNLOAD_ERROR
from .log import logger
from .pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool

__all__ = ["get_weights_path_from_url"]

Expand Down Expand Up @@ -469,3 +471,37 @@ def hf_file_exists(
return True
except EntryNotFoundError:
return False


def download_from_pdc(remote_path, local_path, timeout):
    """Download from remote_path and place to a local_path through PaddleCloud. remote_path has to be uploaded through PaddleCloud as well.

    Args:
        remote_path (`str`):
            remote path url for download
        local_path (`str`):
            local path to place downloaded object
        timeout (`int`):
            max wait time for download

    Raises:
        RuntimeError: if the local destination directory cannot be prepared, or
            the PDC download fails with any code other than Success/LocalPathExist.
    """
    try:
        # BUGFIX: prepare the parent directory of the *local* destination —
        # the object is written to local_path on this machine; splitting the
        # remote url here would create a meaningless directory.
        base_dir, _ = os.path.split(os.path.normpath(local_path))
        if base_dir != "" and not os.path.exists(base_dir):
            os.makedirs(base_dir, exist_ok=True)
    except Exception as e:
        # Chain the original exception so the failure stays debuggable.
        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}") from e
    start_time = time.time()
    result = pdc_tool.pdc_download(remote_path, local_path, timeout)
    end_time = time.time()
    if result == PDCErrorCode.Success:
        logger.info(f"Successfully downloaded object from PDC, total time cost: {end_time - start_time} seconds.")
    elif result == PDCErrorCode.LocalPathExist:
        # A previous attempt (or another local process) already fetched the object; treat as success.
        logger.warning(
            f"Skipping download object since file exists at local, total time cost: {end_time - start_time} seconds."
        )
    else:
        # Use .get so an unmapped error code cannot raise a KeyError and mask the real failure.
        raise RuntimeError(
            f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download object from PDC, remote_path: {remote_path}, local_path: {local_path}, timeout: {timeout}; error details: {PDCErrorMessageMap.get(result, result)}"
        )
34 changes: 34 additions & 0 deletions paddlenlp/utils/fault_tolerance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

# Prefer Paddle's canonical recall-error strings when available so that log
# scanning matches the framework exactly; fall back to hard-coded copies for
# older Paddle versions that do not ship paddle.framework.recall_error.
try:
    from paddle.framework.recall_error import LOSS_NAN_ERROR
except ImportError:
    LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"

try:
    from paddle.framework.recall_error import LOSS_INF_ERROR
except ImportError:
    LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"

# Marker string prepended to PDC download failures; follows the same
# "PaddleRecall error(NNN)" format as the constants above.
PDC_DOWNLOAD_ERROR = "PaddleRecall error(105): PDCDownloadError"


def is_ft_env():
    """
    Check if the current environment is a FT environment.

    Returns:
        bool: True when the PaddleCloud longjob marker variable is present
        in the process environment, False otherwise.
    """
    return os.environ.get("PDC_LONGJOB_ID") is not None
Loading

0 comments on commit a8803d7

Please sign in to comment.