
Commit

Merge remote-tracking branch 'origin/main' into feature/fp8
shahromil16 committed Jun 12, 2024
2 parents 29000f3 + c65b430 commit 936cd9a
Showing 10 changed files with 206 additions and 52 deletions.
97 changes: 65 additions & 32 deletions open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -276,28 +276,33 @@ def preprocess(
tokens = tokenizer_fn(string)
tokens.append(EOT)
buffer += tokens
while len(buffer) >= seqlen:
idx = 0
while idx < len(buffer) - seqlen:
if do_sample:
local_sample_freq = sample_freq
# This code does the following
# yield a int(sample_freq) copies of buffer[:seqlen]
# yield a int(sample_freq) copies of the current sample
# then yield 1 more sample with Pr[sample_freq - int(sample_freq)]
# in expectation we will yield sample_freq copies of buffer[:seqlen]
# in expectation we will yield sample_freq copies of the current sample
while local_sample_freq > 1:
if source_counter is not None:
ray.get(source_counter.increment_token_count.remote(seqlen))
yield buffer[:seqlen]
yield buffer[idx : idx + seqlen]
local_sample_freq -= 1
if rng.random() < local_sample_freq:
if source_counter is not None:
ray.get(source_counter.increment_token_count.remote(seqlen))
yield buffer[:seqlen]
buffer = buffer[seqlen:]
yield buffer[idx : idx + seqlen]
idx += seqlen
else:
if source_counter is not None:
ray.get(source_counter.increment_token_count.remote(seqlen))
yield buffer[:seqlen]
buffer = buffer[seqlen:]
yield buffer[idx : idx + seqlen]
idx += seqlen

# Remove the tokens that have been yielded from the buffer
buffer = buffer[idx:]

if len(buffer) > 0:
if source_counter is not None:
ray.get(source_counter.increment_token_count.remote(len(buffer)))
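The hunk above replaces the repeated `buffer = buffer[seqlen:]` reslicing with an index walk and keeps the fractional-oversampling behaviour described in its comments. A minimal standalone sketch of that sampling idea, covering only the `do_sample` branch and omitting the Ray source counter (`oversample_chunks` is an illustrative name, not a function in the repo):

    import random

    def oversample_chunks(buffer, seqlen, sample_freq, rng=None):
        # Yield each seqlen-token chunk roughly sample_freq times in expectation:
        # int(sample_freq) guaranteed copies, plus one extra copy with
        # probability sample_freq - int(sample_freq).
        rng = rng or random.Random(0)
        idx = 0
        while idx < len(buffer) - seqlen:
            chunk = buffer[idx : idx + seqlen]
            remaining = sample_freq
            while remaining > 1:            # guaranteed copies
                yield chunk
                remaining -= 1
            if rng.random() < remaining:    # fractional extra copy, in expectation
                yield chunk
            idx += seqlen
        # Tokens past idx (fewer than seqlen of them) stay behind, mirroring
        # the `buffer = buffer[idx:]` line in the diff.

    # With sample_freq=2.5 each chunk is emitted 2 or 3 times, 2.5 on average.
    samples = list(oversample_chunks(list(range(10)), seqlen=4, sample_freq=2.5))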
@@ -570,6 +575,7 @@ def main(args):
) # default is localhost; for slurm jobs do 0.0.0.0
parser.add_argument("--suffixes", nargs="+", default=[".json", ".jsonl", ".zst", ".zstd", ".tar", ".gz"])
parser.add_argument("--presort", action="store_true")
parser.add_argument("--allow_imbalanced_write", action="store_true")

args = parser.parse_args(args)
if args.do_sample:
@@ -667,33 +673,60 @@ def main(args):
counter = GlobalCounter.remote()
out_folder = args.output.rstrip("/")
# first map buffer_write over rows, it will create an actor (which hopefully will be scheduled locally)
write_status = ds.map_batches(
buffer_write,
fn_kwargs={
"folder": out_folder,
"counter": counter,
"buffer_size": args.wds_chunk_size,
"num_writers_per_node": num_writers_per_node,
},
zero_copy_batch=True,
batch_size=args.wds_chunk_size,
batch_format="pandas",
).take_all()
# after the write is done, grab all actors of class BufferedShardWriter
buffer_writers_names = set(
[x.name for x in list_actors(filters=[("class_name", "=", "BufferedShardWriter"), ("state", "=", "ALIVE")])]
)
buffer_writers = [ray.get_actor(x) for x in buffer_writers_names]
# flush the remaining buffers, this should be the *only* shards that are less than wds_chunk_size
flushed_buffers = [bw._flush_buffer.remote(out_folder, counter) for bw in buffer_writers]
tail_write_status = [ray.get(buf) for buf in flushed_buffers]
# Grab manifests which are stored in the buffer writers
manifests = [manifest_row for bw in buffer_writers for manifest_row in ray.get(bw.get_manifests.remote())]
manifests_sorted = sorted(manifests, key=lambda x: x["shard"])
write_manifest(manifests_sorted, args)
if args.allow_imbalanced_write:
ds = ds.map_batches(
map_write_wds,
batch_size=wds_chunk_size,
fn_kwargs={
"batch_size": wds_chunk_size,
"folder": out_folder,
"counter": counter,
},
batch_format="pandas",
)
ds = ds.repartition(1)
ds = ds.sort(key="shard")
jsonl_lines = ds.take_all()
token_count_from_manifest = sum([x["num_sequences"][0] for x in jsonl_lines] * seqlen)
write_manifest(jsonl_lines, args)
else:
write_status = ds.map_batches(
buffer_write,
fn_kwargs={
"folder": out_folder,
"counter": counter,
"buffer_size": args.wds_chunk_size,
"num_writers_per_node": num_writers_per_node,
},
zero_copy_batch=True,
batch_size=args.wds_chunk_size,
batch_format="pandas",
).take_all()

# after the write is done, grab all actors of class BufferedShardWriter
buffer_writers_names = set(
[x.name for x in list_actors(filters=[("class_name", "=", "BufferedShardWriter"), ("state", "=", "ALIVE")])]
)
buffer_writers = [ray.get_actor(x) for x in buffer_writers_names]
# flush the remaining buffers, this should be the *only* shards that are less than wds_chunk_size
flushed_buffers = [bw._flush_buffer.remote(out_folder, counter) for bw in buffer_writers]
tail_write_status = [ray.get(buf) for buf in flushed_buffers]
# Grab manifests which are stored in the buffer writers
manifests = [manifest_row for bw in buffer_writers for manifest_row in ray.get(bw.get_manifests.remote())]
manifests_sorted = sorted(manifests, key=lambda x: x["shard"])
token_count_from_manifest = sum([x["num_sequences"] for x in manifests_sorted] * seqlen)
write_manifest(manifests_sorted, args)

end_time = time.time()
duration = end_time - start_time
final_token_count = ray.get(counter.increment_token_count.remote(0))

if token_count_from_manifest != final_token_count:
logger.warning(
f"Token count mismatch: {token_count_from_manifest} from manifest vs {final_token_count} global actor. Please run manifest generation manually via make_wds_manifest.py."
)
# TODO: Generate manifest automatically from the tokenized data if token count mismatch

print("==== Token count summary ====")
print(f"Tokenize + Shuffle script Finished in: {duration}")
print(f"Final Token Count: {final_token_count}")
13 changes: 11 additions & 2 deletions open_lm/file_utils.py
@@ -70,10 +70,19 @@ def remote_sync(local_dir, remote_dir, protocol):
return False


def remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol, max_retries=6):
for i in range(max_retries):
time.sleep(sync_every * 2**i)
success = remote_sync(local_dir, remote_dir, protocol)
if success:
return True

return False


def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
while True:
time.sleep(sync_every)
remote_sync(local_dir, remote_dir, protocol)
remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol)


def start_sync_process(sync_every, local_dir, remote_dir, protocol):
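A brief usage sketch of the new `remote_sync_with_expon_backoff` helper (the signature comes from the diff; the paths, protocol value, and timings are only an illustration):

    # With sync_every=60 and the default max_retries=6, attempts happen after
    # sleeping 60, 120, 240, 480, 960, and 1920 seconds (60 * 2**i), i.e.
    # roughly 63 minutes of total waiting before the helper returns False.
    ok = remote_sync_with_expon_backoff(
        sync_every=60,
        local_dir="/tmp/logs/my-run",        # hypothetical local log dir
        remote_dir="s3://my-bucket/my-run",  # hypothetical remote target
        protocol="s3",
    )
    if not ok:
        print("remote sync failed after exponential backoff")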
75 changes: 70 additions & 5 deletions open_lm/main.py
@@ -58,7 +58,8 @@
pt_load,
check_exists,
start_sync_process,
remote_sync,
remote_sync_with_expon_backoff,
get_metadata_file,
get_string_for_epoch,
log_num_checkpoints,
terminate_sync_process,
@@ -199,10 +200,13 @@ def save_checkpoint(
evaluation_metrics,
step,
is_final_checkpoint,
percentage_of_data_seen=-1.0,
next_shard_per_source=None,
samples_seen=None,
shard_shuffle_seed=None,
train_data_string=None,
averagers=None,
failed=False,
):
cpu_state, optim_state = None, None
if args.logs and args.logs.lower() != "none" and args.fsdp:
@@ -244,7 +248,22 @@
"name": args.name,
"is_final_checkpoint": is_final_checkpoint,
"evaluation_metrics": evaluation_metrics,
"percentage_of_data_seen": percentage_of_data_seen,
}
if next_shard_per_source is not None:
checkpoint_dict_stats["next_shard_per_source"] = next_shard_per_source

if samples_seen is not None:
checkpoint_dict_stats["samples_seen"] = samples_seen

if step is not None:
checkpoint_dict_stats["step"] = step

if shard_shuffle_seed is not None:
checkpoint_dict_stats["shard_shuffle_seed"] = shard_shuffle_seed

if train_data_string is not None:
checkpoint_dict_stats["train_data_string"] = train_data_string

prefixes = {
"epoch_": checkpoint_dict_model,
@@ -261,7 +280,8 @@
or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
):
for prefix in prefixes:
path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
save_path = args.checkpoint_path if not failed else args.failed_checkpoint_path
path = os.path.join(save_path, f"{prefix}{completed_epoch}.pt")
print(f"Saving {prefix}{completed_epoch} in {path}...")
torch.save(
prefixes[prefix],
@@ -358,9 +378,10 @@ def main(args):
args.wandb = "wandb" in args.report_to or "all" in args.report_to
args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
args.failed_checkpoint_path = os.path.join(log_base_path, "checkpoints_failed")
if is_master(args):
args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ""
for dirname in [args.tensorboard_path, args.checkpoint_path]:
for dirname in [args.tensorboard_path, args.checkpoint_path, args.failed_checkpoint_path]:
if dirname:
os.makedirs(dirname, exist_ok=True)
else:
@@ -403,7 +424,8 @@ def main(args):
remote_sync_process = None
if is_master(args) and args.remote_sync is not None:
# first make sure it works
result = remote_sync(
result = remote_sync_with_expon_backoff(
args.remote_sync_frequency,
os.path.join(args.logs, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol,
@@ -757,6 +779,7 @@ def main(args):
# Only enter training loop if there are steps to be done.
done_training = global_step >= total_steps
epoch = start_epoch
num_ckpt_too_few_tokens = 0
while not done_training:
if is_master(args):
logging.info(f"Start epoch {epoch}")
@@ -829,6 +852,16 @@
logging.info("Training exiting due to NaN value")
break

failed_ckpt = False
expected_steps = data["train"].dataloader.num_batches
if steps_done_epoch < (1 - args.data_tolerate_error_p) * expected_steps and not done_training:
failed_ckpt = True
num_ckpt_too_few_tokens += 1
if is_master(args):
logging.warning(
f"Epoch {epoch}, tokens seen: {steps_done_epoch * args.global_batch_size * args.seq_len}, tokens expected: {expected_steps * args.global_batch_size * args.seq_len}, ratio: {steps_done_epoch / expected_steps}"
)

epoch = epoch + 1
evaluation_metrics = []
if "val_list" in data and (epoch % args.val_frequency == 0 or done_training):
@@ -846,6 +879,29 @@
logging.error(traceback.format_exc())
logging.warning("evaluation failed! continuing to save_checkpoint")

if is_master(args):
end_of_epoch_log = {
"epoch": epoch,
"tokens": (global_step + 1) * args.global_batch_size * args.seq_len,
"checkpoints_too_few_tokens": num_ckpt_too_few_tokens,
"percentage_of_data_seen": steps_done_epoch / expected_steps,
}

if args.dataset_manifest is not None:
for i in range(len(next_shard_per_source)):
end_of_epoch_log[f"next_shard_{i}"] = next_shard_per_source[i]
end_of_epoch_log[f"dataset_pass_{i}"] = next_shard_per_source[i] // len(
get_metadata_file(args.dataset_manifest[i])
)

for name, val in end_of_epoch_log.items():
name = "train/" + name
if writer is not None:
writer.add_scalar(name, val, global_step)
if args.wandb:
assert wandb is not None, "Please install wandb."
wandb.log({name: val, "step": global_step, "tokens": end_of_epoch_log["tokens"]})

# Saving checkpoints.
save_checkpoint(
args,
Expand All @@ -856,12 +912,20 @@ def main(args):
evaluation_metrics,
step=global_step,
is_final_checkpoint=done_training,
percentage_of_data_seen=1.0 * steps_done_epoch / expected_steps,
next_shard_per_source=next_shard_per_source if args.dataset_manifest is not None else None,
samples_seen=samples_seen if args.dataset_manifest is not None else None,
shard_shuffle_seed=args.shard_shuffle_seed,
train_data_string=train_data_string_per_source if args.dataset_manifest is not None else None,
averagers=averagers,
failed=failed_ckpt,
)

if num_ckpt_too_few_tokens > args.data_tolerate_num_ckpts:
raise RuntimeError(
f"{num_ckpt_too_few_tokens} checkpoints happened where the number of tokens seen was {1 - args.data_tolerate_error_p} of expected. This is likely due to transient errors e.g. reading from S3."
)

if done_training:
if is_master(args):
logging.info("Model has seen the desired number of tokens. Ending training.")
@@ -874,7 +938,8 @@
if remote_sync_process is not None:
logging.info("Final remote sync.")
terminate_sync_process(remote_sync_process)
result = remote_sync(
result = remote_sync_with_expon_backoff(
args.remote_sync_frequency,
os.path.join(args.logs, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol,
17 changes: 16 additions & 1 deletion open_lm/params.py
@@ -196,6 +196,10 @@ def check_args(args):
# Make sure that val batch size is set to micro batch size
args.global_val_batch_size = args.global_batch_size // args.accum_freq

assert (
args.train_data is None or args.dataset_manifest is None
), "--dataset-manifest and --train-data cannot both be set"

# custom_attn checks
if args.attn_name == "custom_attn":
assert (
@@ -771,7 +775,18 @@ def parse_args(args):
default=0,
help="Whether to log the average model training loss. if not 0, it will log the average loss over the specified number of steps.",
)

parser.add_argument(
"--data-tolerate-error-p",
type=float,
default=0.09, # Roughly the number required to not repeat more than 10% of data.
help="This is the percentage of expected tokens above which the checkpoint is considered failed because of not having seen enough data.",
)
parser.add_argument(
"--data-tolerate-num-ckpts",
type=int,
default=0,
help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed",
)
parser.add_argument(
"--use-fp8",
action="store_true",
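A standalone sketch of how the two new tolerance flags parse (plain argparse, not the project's full parser; the dest names match what main.py reads):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--data-tolerate-error-p", type=float, default=0.09)
    parser.add_argument("--data-tolerate-num-ckpts", type=int, default=0)

    args = parser.parse_args(["--data-tolerate-error-p", "0.05"])
    # argparse maps dashes to underscores, so main.py reads these as
    # args.data_tolerate_error_p and args.data_tolerate_num_ckpts.
    print(args.data_tolerate_error_p, args.data_tolerate_num_ckpts)  # 0.05 0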
(diffs for the remaining 6 changed files are not shown here)
