Improve error handling for s3 read errors. (#273)
* Improve error handling.

* Renamed args.

* Bugfix.

* Another small bugfix.

* Small relaxation in tests.

* Another fix in tests.

* Revert tests and update expected steps count.

* Fix error in final ckpt

* Fix name

* Expand on elements being saved in stats file.

* Formatting

* Add more detailed logging on next shard per source.

* Even more detailed logging.

* Add percent of data seen.

---------

Co-authored-by: George Smyrnis <[email protected]>
GeorgiosSmyrnis authored May 14, 2024
1 parent b47fd05 commit 3b4a063
Showing 2 changed files with 63 additions and 0 deletions.
open_lm/main.py (50 additions, 0 deletions)
@@ -59,6 +59,7 @@
     check_exists,
     start_sync_process,
     remote_sync_with_expon_backoff,
+    get_metadata_file,
     get_string_for_epoch,
     log_num_checkpoints,
     terminate_sync_process,
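
For context, get_metadata_file (imported here) is used further down to count the shards listed in a source's manifest. A minimal sketch of the assumed behavior, treating the manifest as a local jsonl file with one JSON record per shard (the real helper may also handle remote paths such as s3://):

import json

def get_metadata_file(path):
    # Sketch only: assume a local jsonl manifest with one record per shard.
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
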
@@ -203,6 +204,7 @@ def save_checkpoint(
     next_shard_per_source=None,
     samples_seen=None,
     shard_shuffle_seed=None,
+    train_data_string=None,
     averagers=None,
 ):
     cpu_state, optim_state = None, None
@@ -246,6 +248,20 @@ def save_checkpoint(
             "is_final_checkpoint": is_final_checkpoint,
             "evaluation_metrics": evaluation_metrics,
         }
+        if next_shard_per_source is not None:
+            checkpoint_dict_stats["next_shard_per_source"] = next_shard_per_source
+
+        if samples_seen is not None:
+            checkpoint_dict_stats["samples_seen"] = samples_seen
+
+        if step is not None:
+            checkpoint_dict_stats["step"] = step
+
+        if shard_shuffle_seed is not None:
+            checkpoint_dict_stats["shard_shuffle_seed"] = shard_shuffle_seed
+
+        if train_data_string is not None:
+            checkpoint_dict_stats["train_data_string"] = train_data_string

         prefixes = {
             "epoch_": checkpoint_dict_model,
@@ -752,6 +768,7 @@ def main(args):
     # Only enter training loop if there are steps to be done.
     done_training = global_step >= total_steps
     epoch = start_epoch
+    num_ckpt_too_few_tokens = 0
     while not done_training:
         if is_master(args):
             logging.info(f"Start epoch {epoch}")
@@ -823,6 +840,15 @@ def main(args):
             logging.info("Training exiting due to NaN value")
             break

+        expected_steps = data["train"].dataloader.num_batches
+        if steps_done_epoch < (1 - args.data_tolerate_error_p) * expected_steps and not done_training:
+            num_ckpt_too_few_tokens += 1
+
+            if num_ckpt_too_few_tokens > args.data_tolerate_num_ckpts:
+                raise RuntimeError(
+                    f"{num_ckpt_too_few_tokens} checkpoints happened where the number of tokens seen was less than {1 - args.data_tolerate_error_p} of expected. This is likely due to transient errors e.g. reading from S3."
+                )
+
         epoch = epoch + 1
         evaluation_metrics = []
         if "val_list" in data and (epoch % args.val_frequency == 0 or done_training):
@@ -840,6 +866,29 @@ def main(args):
                 logging.error(traceback.format_exc())
                 logging.warning("evaluation failed! continuing to save_checkpoint")

+        if is_master(args):
+            end_of_epoch_log = {
+                "epoch": epoch,
+                "tokens": (global_step + 1) * args.global_batch_size * args.seq_len,
+                "checkpoints_too_few_tokens": num_ckpt_too_few_tokens,
+                "percentage_of_data_seen": steps_done_epoch / expected_steps,
+            }
+
+            if args.dataset_manifest is not None:
+                for i in range(len(next_shard_per_source)):
+                    end_of_epoch_log[f"next_shard_{i}"] = next_shard_per_source[i]
+                    end_of_epoch_log[f"dataset_pass_{i}"] = next_shard_per_source[i] // len(
+                        get_metadata_file(args.dataset_manifest[i])
+                    )
+
+            for name, val in end_of_epoch_log.items():
+                name = "train/" + name
+                if writer is not None:
+                    writer.add_scalar(name, val, global_step)
+                if args.wandb:
+                    assert wandb is not None, "Please install wandb."
+                    wandb.log({name: val, "step": global_step, "tokens": end_of_epoch_log["tokens"]})
+
         # Saving checkpoints.
         save_checkpoint(
             args,
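
The dataset_pass_{i} value logged above is the number of completed passes over source i: the next shard pointer divided by the shard count in that source's manifest. A sketch with hypothetical counts:

next_shard_per_source = [2600, 150]   # next shard index per source (hypothetical)
shards_per_source = [2048, 512]       # len(get_metadata_file(manifest_i)) per source

for i in range(len(next_shard_per_source)):
    print(f"train/next_shard_{i} = {next_shard_per_source[i]}")
    # 2600 // 2048 = 1 -> the first source is on its second pass over the data
    print(f"train/dataset_pass_{i} = {next_shard_per_source[i] // shards_per_source[i]}")
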
@@ -853,6 +902,7 @@ def main(args):
             next_shard_per_source=next_shard_per_source if args.dataset_manifest is not None else None,
             samples_seen=samples_seen if args.dataset_manifest is not None else None,
             shard_shuffle_seed=args.shard_shuffle_seed,
+            train_data_string=train_data_string_per_source if args.dataset_manifest is not None else None,
             averagers=averagers,
         )

open_lm/params.py (13 additions, 0 deletions)
@@ -775,6 +775,19 @@ def parse_args(args):
         default=0,
         help="Whether to log the average model training loss. if not 0, it will log the average loss over the specified number of steps.",
     )
+    parser.add_argument(
+        "--data-tolerate-error-p",
+        type=float,
+        default=0.09,  # Roughly the number required to not repeat more than 10% of data.
+        help="Tolerated fraction of expected tokens that may be missing before a checkpoint is considered failed for not having seen enough data.",
+    )
+    parser.add_argument(
+        "--data-tolerate-num-ckpts",
+        type=int,
+        default=0,
+        help="Maximum number of failed checkpoints (due to not having seen enough tokens) allowed before training raises an error.",
+    )
+
     add_model_args(parser)

     config = maybe_load_config(parser, args)
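
A hypothetical way to exercise the two new flags through parse_args, assuming the remaining arguments have workable defaults:

from open_lm.params import parse_args

# Tolerate up to two checkpoints that each miss at most 5% of expected tokens.
args = parse_args(["--data-tolerate-error-p", "0.05", "--data-tolerate-num-ckpts", "2"])
print(args.data_tolerate_error_p, args.data_tolerate_num_ckpts)  # 0.05 2
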