From 1f6a71f267f9916579c04b07cfb4c4178f9cfb0e Mon Sep 17 00:00:00 2001
From: Jeffrey Li
Date: Tue, 2 Jan 2024 00:03:04 -0800
Subject: [PATCH 1/2] add in max buffer size

---
 .../datapreprocess/ray/tokenize_shuffle.py | 53 +++++++++++--------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/open_lm/datapreprocess/ray/tokenize_shuffle.py b/open_lm/datapreprocess/ray/tokenize_shuffle.py
index f64eb8d6..1cd625b9 100644
--- a/open_lm/datapreprocess/ray/tokenize_shuffle.py
+++ b/open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -256,6 +256,7 @@ def preprocess(
     do_sample: bool = False,
     sources: enum.Enum = None,
     source_counter: GlobalCounter = None,
+    max_buffer_seqs: int = 1000,
 ):
     tokenizer_fn, vocab_size = tokenizer
     rng = random.Random(hash(key) + seed)
@@ -275,29 +276,35 @@ def preprocess(
     for string in pbar:
         tokens = tokenizer_fn(string)
         tokens.append(EOT)
-        buffer += tokens
-        while len(buffer) >= seqlen:
-            if do_sample:
-                local_sample_freq = sample_freq
-                # This code does the following
-                # yield a int(sample_freq) copies of buffer[:seqlen]
-                # then yield 1 more sample with Pr[sample_freq - int(sample_freq)]
-                # in expectation we will yield sample_freq copies of buffer[:seqlen]
-                while local_sample_freq > 1:
+        while len(tokens) > 0:
+            # Add tokens to the buffer while controlling buffer, speeds up slicing for large documents
+            idx = min(seqlen*max_buffer_seqs-len(buffer), len(tokens))
+            buffer += tokens[:idx]
+            tokens = tokens[idx:]
+
+            while len(buffer) >= seqlen:
+                if do_sample:
+                    local_sample_freq = sample_freq
+                    # This code does the following
+                    # yield a int(sample_freq) copies of buffer[:seqlen]
+                    # then yield 1 more sample with Pr[sample_freq - int(sample_freq)]
+                    # in expectation we will yield sample_freq copies of buffer[:seqlen]
+                    while local_sample_freq > 1:
+                        if source_counter is not None:
+                            ray.get(source_counter.increment_token_count.remote(seqlen))
+                        yield buffer[:seqlen]
+                        local_sample_freq -= 1
+                    if rng.random() < local_sample_freq:
+                        if source_counter is not None:
+                            ray.get(source_counter.increment_token_count.remote(seqlen))
+                        yield buffer[:seqlen]
+                    buffer = buffer[seqlen:]
+                else:
                     if source_counter is not None:
                         ray.get(source_counter.increment_token_count.remote(seqlen))
                     yield buffer[:seqlen]
-                    local_sample_freq -= 1
-                if rng.random() < local_sample_freq:
-                    if source_counter is not None:
-                        ray.get(source_counter.increment_token_count.remote(seqlen))
-                    yield buffer[:seqlen]
-                buffer = buffer[seqlen:]
-            else:
-                if source_counter is not None:
-                    ray.get(source_counter.increment_token_count.remote(seqlen))
-                yield buffer[:seqlen]
-                buffer = buffer[seqlen:]
+                    buffer = buffer[seqlen:]
+
     if len(buffer) > 0:
         if source_counter is not None:
             ray.get(source_counter.increment_token_count.remote(len(buffer)))
@@ -308,7 +315,7 @@ def preprocess(
     return []


-def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None):
+def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None, max_buffer_seqs=1000):
     path = data["path"]

     if path.startswith("s3"):
@@ -337,6 +344,7 @@ def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=
         do_sample=do_sample,
         sources=sources,
         source_counter=source_counter,
+        max_buffer_seqs=max_buffer_seqs,
     )

     # Ensure that all operations on the file handle are done within this block
@@ -570,6 +578,8 @@ def main(args):
     )  # default is localhost; for slurm jobs do 0.0.0.0
     parser.add_argument("--suffixes", nargs="+", default=[".json", ".jsonl", ".zst", ".zstd", ".tar", ".gz"])
     parser.add_argument("--presort", action="store_true")
+    parser.add_argument("--max_buffer_seqs", type=int, default=1000)
+
     args = parser.parse_args(args)

     if args.do_sample:
@@ -655,6 +665,7 @@ def main(args):
             do_sample=args.do_sample,
             sources=Sources,
             source_counters=source_counters,
+            max_buffer_seqs=args.max_buffer_seqs,
         )
     )
     ds = ds.map(add_hash)

From f9dd1aeff644d2b45dc31a310eb8e1ddc3db30b3 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Fri, 10 May 2024 01:14:52 -0700
Subject: [PATCH 2/2] linting

---
 open_lm/datapreprocess/ray/tokenize_shuffle.py | 11 ++++++-----
 open_lm/utils/convert_llama.py                 |  1 -
 tests/test_dataset_no_resample.py              |  1 -
 tests/test_file_utils.py                       |  1 -
 tests/test_training_tokens.py                  |  2 +-
 5 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/open_lm/datapreprocess/ray/tokenize_shuffle.py b/open_lm/datapreprocess/ray/tokenize_shuffle.py
index 1cd625b9..87a5f2ed 100644
--- a/open_lm/datapreprocess/ray/tokenize_shuffle.py
+++ b/open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -277,9 +277,9 @@ def preprocess(
         tokens = tokenizer_fn(string)
         tokens.append(EOT)
         while len(tokens) > 0:
-            # Add tokens to the buffer while controlling buffer, speeds up slicing for large documents
-            idx = min(seqlen*max_buffer_seqs-len(buffer), len(tokens))
-            buffer += tokens[:idx]
+            # Add tokens to the buffer while controlling buffer, speeds up slicing for large documents
+            idx = min(seqlen * max_buffer_seqs - len(buffer), len(tokens))
+            buffer += tokens[:idx]
             tokens = tokens[idx:]

             while len(buffer) >= seqlen:
@@ -315,7 +315,9 @@ def preprocess(
     return []


-def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None, max_buffer_seqs=1000):
+def process_keys(
+    data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None, max_buffer_seqs=1000
+):
     path = data["path"]

     if path.startswith("s3"):
@@ -580,7 +582,6 @@ def main(args):
     parser.add_argument("--presort", action="store_true")
     parser.add_argument("--max_buffer_seqs", type=int, default=1000)
-
     args = parser.parse_args(args)

     if args.do_sample:
         Sources, SAMPLING_FREQUENCIES = load_from_yaml(args.default_dataset_yaml)
diff --git a/open_lm/utils/convert_llama.py b/open_lm/utils/convert_llama.py
index 22240ab0..3c879116 100644
--- a/open_lm/utils/convert_llama.py
+++ b/open_lm/utils/convert_llama.py
@@ -3,7 +3,6 @@
 Usage:
 `python convert_llama_to_openlm.py `
 """
-
 import torch
 import sys

diff --git a/tests/test_dataset_no_resample.py b/tests/test_dataset_no_resample.py
index a1dbcf61..bcb5aefb 100644
--- a/tests/test_dataset_no_resample.py
+++ b/tests/test_dataset_no_resample.py
@@ -11,7 +11,6 @@


 """
-
 import pytest
 import random
 import os
diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py
index a6cee390..af61ff0a 100644
--- a/tests/test_file_utils.py
+++ b/tests/test_file_utils.py
@@ -4,7 +4,6 @@


 """
-
 from open_lm.file_utils import get_string_for_epoch
 import pytest

diff --git a/tests/test_training_tokens.py b/tests/test_training_tokens.py
index 3640597f..16ce4b9d 100644
--- a/tests/test_training_tokens.py
+++ b/tests/test_training_tokens.py
@@ -17,7 +17,7 @@
         (100, 2, 1000, 4, [20, 40]),  # Easy case.
         (100, 2, 1200, 4, [20, 40, 48]),  # End before consuming all in a shard.
         (100, 2, 1500, 4, [20, 40, 54, 60]),  # One of the shards here is smaller. 54 instead of 56 because of workers.
-        (85, 2, 1000, 4, [22, 44, 47])  # Batch weirdness, total_steps = 1000 * 4 // 85 = 47,
+        (85, 2, 1000, 4, [22, 44, 47]),  # Batch weirdness, total_steps = 1000 * 4 // 85 = 47,
         # steps_epoch = 2000 // (85 * 2) * 2 = 22
     ],
 )
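For readers who want the gist of the change without tracing the diff: the first patch caps the token buffer in preprocess at seqlen * max_buffer_seqs tokens, topping it up and draining it in alternation so slicing stays cheap even for very large documents, and exposes the cap as a --max_buffer_seqs flag (default 1000). Below is a minimal, self-contained sketch of just that buffering loop; the EOT handling, the sampling path, and the Ray source_counter bookkeeping from the real code are omitted, and chunk_tokens plus the toy inputs are hypothetical names used only for illustration.

from typing import Iterator, List


def chunk_tokens(docs: List[List[int]], seqlen: int, max_buffer_seqs: int = 1000) -> Iterator[List[int]]:
    """Yield fixed-length sequences while keeping the working buffer at most
    seqlen * max_buffer_seqs tokens, mirroring the capped-buffer loop in the patch."""
    buffer: List[int] = []
    for doc in docs:
        tokens = list(doc)  # local copy so we can consume it in place
        while len(tokens) > 0:
            # Move at most enough tokens to fill the buffer up to its cap.
            idx = min(seqlen * max_buffer_seqs - len(buffer), len(tokens))
            buffer += tokens[:idx]
            tokens = tokens[idx:]
            # Drain full sequences before topping the buffer up again.
            while len(buffer) >= seqlen:
                yield buffer[:seqlen]
                buffer = buffer[seqlen:]
    # The trailing partial sequence is handled separately in the real code.


if __name__ == "__main__":
    # Two toy "documents" of token ids; seqlen=4 with a cap of 2 sequences in the buffer.
    docs = [list(range(10)), list(range(100, 107))]
    for seq in chunk_tokens(docs, seqlen=4, max_buffer_seqs=2):
        print(seq)

Running the example keeps the buffer at no more than 8 tokens at any point while still yielding every full 4-token sequence in order, which is the behavior the patch adds for large documents.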