Skip to content

Commit

Permalink
Guard against indices with 0 samples
Browse files Browse the repository at this point in the history
  • Loading branch information
undfined committed Jan 23, 2025
1 parent 9818232 commit 4934c7c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/olmo_core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,14 @@ def segment_documents_into_instances(
rng = get_rng(seed)
indices = rng.choice(indices.reshape(-1, 2), size=max_instances).reshape(-1)

with memmap_to_write(target, dtype=indices_dtype, shape=(indices.size,)) as indices_mmap:
indices_mmap[:] = indices
# NOTE: It's possible to sample 0 instances from small source files. Rather than try to write this empty array we conditionally skip if indices_out is 0.
indices_out = len(indices) // 2

return total_og_docs, len(indices) // 2
if indices_out > 0:
with memmap_to_write(target, dtype=indices_dtype, shape=(indices.size,)) as indices_mmap:
indices_mmap[:] = indices

return total_og_docs, indices_out


def run_worker_func(func, *args, **kwargs):
Expand Down
32 changes: 32 additions & 0 deletions src/test/data/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,38 @@ def test_segment_documents_into_instances(tmp_path):
assert all([r[1] == 2 for r in results])


@pytest.mark.limit_memory("245 KB")
def test_segment_documents_into_instances_zero_sample(tmp_path):
data = [1, 2, 3, 4, 0, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0] * 1000
data_path = tmp_path / "data.npy"
max_sequence_length = 4
mmap = np.memmap(data_path, mode="w+", dtype=np.uint16, shape=(len(data),))
indices_path = tmp_path / "indices.npy"
mmap[:] = data
mmap.flush()

eos = 0
dtype = np.uint16
# Explicitly set sample to zero
sample = (0, 42)

results = []
# Should not raise ValueError, but should return zero instances for each iteration
for _ in range(10):
results.append(
segment_documents_into_instances(
path=data_path,
target=indices_path,
max_sequence_length=max_sequence_length,
eos_token_id=eos,
dtype=dtype,
sample=sample,
)
)

assert all([r[1] == 0 for r in results])


def test_iter_document_indices(tmp_path):
data = [1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0]
data_path = tmp_path / "data.npy"
Expand Down

0 comments on commit 4934c7c

Please sign in to comment.