Commit

add in max buffer size
jeffreywpli committed May 9, 2024
1 parent 45879d2 commit 1f6a71f
Showing 1 changed file with 32 additions and 21 deletions.
open_lm/datapreprocess/ray/tokenize_shuffle.py: 53 changes (32 additions & 21 deletions)
@@ -256,6 +256,7 @@ def preprocess(
     do_sample: bool = False,
     sources: enum.Enum = None,
     source_counter: GlobalCounter = None,
+    max_buffer_seqs: int = 1000,
 ):
     tokenizer_fn, vocab_size = tokenizer
     rng = random.Random(hash(key) + seed)
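For scale: with the default max_buffer_seqs=1000 and a hypothetical seqlen of 2048, the buffer below is capped at seqlen * max_buffer_seqs = 2,048,000 tokens, so the cost of each buffer slice is bounded instead of growing with document length.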
@@ -275,29 +276,35 @@ def preprocess(
         for string in pbar:
             tokens = tokenizer_fn(string)
             tokens.append(EOT)
-            buffer += tokens
-            while len(buffer) >= seqlen:
-                if do_sample:
-                    local_sample_freq = sample_freq
-                    # This code does the following:
-                    # yield int(sample_freq) copies of buffer[:seqlen],
-                    # then yield 1 more sample with Pr[sample_freq - int(sample_freq)];
-                    # in expectation we yield sample_freq copies of buffer[:seqlen]
-                    while local_sample_freq > 1:
-                        if source_counter is not None:
-                            ray.get(source_counter.increment_token_count.remote(seqlen))
-                        yield buffer[:seqlen]
-                        local_sample_freq -= 1
-                    if rng.random() < local_sample_freq:
-                        if source_counter is not None:
-                            ray.get(source_counter.increment_token_count.remote(seqlen))
-                        yield buffer[:seqlen]
-                    buffer = buffer[seqlen:]
-                else:
-                    if source_counter is not None:
-                        ray.get(source_counter.increment_token_count.remote(seqlen))
-                    yield buffer[:seqlen]
-                    buffer = buffer[seqlen:]
+            while len(tokens) > 0:
+                # Add tokens to the buffer while capping its size; this speeds up slicing for large documents
+                idx = min(seqlen * max_buffer_seqs - len(buffer), len(tokens))
+                buffer += tokens[:idx]
+                tokens = tokens[idx:]
+
+                while len(buffer) >= seqlen:
+                    if do_sample:
+                        local_sample_freq = sample_freq
+                        # This code does the following:
+                        # yield int(sample_freq) copies of buffer[:seqlen],
+                        # then yield 1 more sample with Pr[sample_freq - int(sample_freq)];
+                        # in expectation we yield sample_freq copies of buffer[:seqlen]
+                        while local_sample_freq > 1:
+                            if source_counter is not None:
+                                ray.get(source_counter.increment_token_count.remote(seqlen))
+                            yield buffer[:seqlen]
+                            local_sample_freq -= 1
+                        if rng.random() < local_sample_freq:
+                            if source_counter is not None:
+                                ray.get(source_counter.increment_token_count.remote(seqlen))
+                            yield buffer[:seqlen]
+                        buffer = buffer[seqlen:]
+                    else:
+                        if source_counter is not None:
+                            ray.get(source_counter.increment_token_count.remote(seqlen))
+                        yield buffer[:seqlen]
+                        buffer = buffer[seqlen:]
 
         if len(buffer) > 0:
             if source_counter is not None:
                 ray.get(source_counter.increment_token_count.remote(len(buffer)))
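The key change in this hunk is the bounded buffer: buffer = buffer[seqlen:] copies the whole remaining list, so with an unbounded buffer one very long document makes every slice cost O(document length). Refilling at most seqlen * max_buffer_seqs tokens at a time bounds each copy. A minimal standalone sketch of the same pattern (the helper name and the example sizes are hypothetical, not part of the commit):

# Minimal sketch of the bounded-buffer pattern above (hypothetical helper, not in the repo).
def bounded_chunks(tokens, seqlen, max_buffer_seqs=1000):
    buffer = []
    while len(tokens) > 0:
        # Refill up to seqlen * max_buffer_seqs tokens so each buffer[seqlen:]
        # copy below is bounded, regardless of how long the document is.
        idx = min(seqlen * max_buffer_seqs - len(buffer), len(tokens))
        buffer += tokens[:idx]
        tokens = tokens[idx:]
        while len(buffer) >= seqlen:
            yield buffer[:seqlen]
            buffer = buffer[seqlen:]
    if buffer:
        yield buffer  # trailing partial chunk (handled separately in the real code)

# Example: a 10-token "document" cut into sequences of 4.
assert list(bounded_chunks(list(range(10)), seqlen=4, max_buffer_seqs=2)) == [
    [0, 1, 2, 3], [4, 5, 6, 7], [8, 9],
]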
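The do_sample branch itself is unchanged apart from indentation; it implements fractional oversampling: yield int(sample_freq) full copies, then one extra copy with probability sample_freq - int(sample_freq), so the expected copy count is exactly sample_freq. A self-contained sketch of that logic (function name is hypothetical):

import random

def yield_with_freq(sample, sample_freq, rng=random):
    # Yield int(sample_freq) copies unconditionally...
    local_sample_freq = sample_freq
    while local_sample_freq > 1:
        yield sample
        local_sample_freq -= 1
    # ...then one more with probability equal to the leftover fraction, so
    # E[copies] == sample_freq (e.g. sample_freq=2.3 -> 2 copies + 30% chance of a 3rd).
    if rng.random() < local_sample_freq:
        yield sample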
@@ -308,7 +315,7 @@ def preprocess(
         return []
 
 
-def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None):
+def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None, max_buffer_seqs=1000):
     path = data["path"]
 
     if path.startswith("s3"):
@@ -337,6 +344,7 @@ def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=
             do_sample=do_sample,
             sources=sources,
             source_counter=source_counter,
+            max_buffer_seqs=max_buffer_seqs,
         )
 
     # Ensure that all operations on the file handle are done within this block
@@ -570,6 +578,8 @@ def main(args):
     ) # default is localhost; for slurm jobs do 0.0.0.0
     parser.add_argument("--suffixes", nargs="+", default=[".json", ".jsonl", ".zst", ".zstd", ".tar", ".gz"])
     parser.add_argument("--presort", action="store_true")
+    parser.add_argument("--max_buffer_seqs", type=int, default=1000)
+
 
     args = parser.parse_args(args)
     if args.do_sample:
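Assuming the script's existing --input/--output flags (placeholders here, not shown in this diff), the new option could be exercised like:

from open_lm.datapreprocess.ray.tokenize_shuffle import main

# Lower the cap if worker memory is tight; raise it to amortize more slicing.
main([
    "--input", "s3://my-bucket/raw-shards/",    # placeholder path
    "--output", "s3://my-bucket/tokenized/",    # placeholder path
    "--max_buffer_seqs", "500",
])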
@@ -655,6 +665,7 @@ def main(args):
             do_sample=args.do_sample,
             sources=Sources,
             source_counters=source_counters,
+            max_buffer_seqs=args.max_buffer_seqs,
         )
     )
     ds = ds.map(add_hash)
