Less contention in tokenization (#893)
Hopefully avoiding the $$$ we've seen tokenizing
dlwh authored Feb 26, 2025
1 parent cc5f730 commit d07b0ca
Showing 3 changed files with 106 additions and 39 deletions.
2 changes: 0 additions & 2 deletions src/levanter/data/loader.py
@@ -307,9 +307,7 @@ async def _dataset_get_available_batch_number(self, target_max_batch_number: int
"""
if self.dl.data_store.is_finite():
next_end = self.dl.scheduler.global_data_offset_by_step(target_max_batch_number)
logger.info(f"waiting for {next_end}")
available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
logger.info(f"for {next_end} got {available_len}")

at_the_end = available_len < next_end

135 changes: 101 additions & 34 deletions src/levanter/store/cache.py
@@ -1081,6 +1081,9 @@ def _copy_temp_caches_to_final_cache(
total_rows_from_caches = overall_ledger.total_num_rows
copy_refs: dict[str, ray.ObjectRef] = {}

metadata_copier = _MetadataCopier.options(name=f"metadata_copier::{cache_dir}").remote(parent)
copy_metadata_refs: dict[str, ray.ObjectRef] = {}

parent._report_copy_progress.remote(
_ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows)
)
@@ -1129,6 +1132,14 @@ def _copy_temp_caches_to_final_cache(
total_rows_from_caches,
parent,
)
copy_metadata_refs[group] = metadata_copier.copy_metadata.remote(
cache_dir,
group_cache_paths[group],
processor_ref,
data_offset_tree,
total_rows_from_caches,
)

this_rows = this_ledger.total_num_rows
total_rows_from_caches += this_rows

@@ -1145,6 +1156,7 @@ def _copy_temp_caches_to_final_cache(
num_available_rows = overall_ledger.total_num_rows
for group, ref in copy_refs.items():
ray.get(ref) # block on data copy
ray.get(copy_metadata_refs[group]) # block on metadata copy

group_ledger = group_ledgers[group]
num_available_rows += group_ledger.total_num_rows
@@ -1195,6 +1207,39 @@ def _copy_cache_data(dest_path, source_path, processor, data_offset_tree, rows_s
asyncio.run(_extend_cache_with_other_cache(dest_path, source_path, processor, data_offset_tree, rows_so_far))


@ray.remote(
num_cpus=2,
memory=1 * 1024 * 1024 * 1024,
runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}),
)
class _MetadataCopier:
"""Copies the metadata from one cache to another. We use an actor because we want to impose the impliacit
actor mutex lock on the metadata file to prevent concurrent writes. We have found that using ts Transactions here
results in way too many retries, resulting in $$$$. If we prevent concurrent writes, we can avoid this."""

# DO NOT ADD ASYNC METHODS TO THIS CLASS. It will remove the implicit lock and cause concurrent writes.
def __init__(self, parent):
self.parent = parent

def copy_metadata(self, dest_path, source_path, processor, data_offset_tree, rows_so_far):
"""
Copies the metadata (offsets and, if present, shapes) from one cache to another, appending it to the end
of the destination cache.
Args:
dest_path: The path to the destination cache.
source_path: The path to the source cache.
processor: The processor used to create the cache.
data_offset_tree: The data offset tree for the destination cache.
rows_so_far: The total number of rows in the destination cache before this copy.
"""
with log_failures_to(self.parent):
asyncio.run(
_extend_cache_metadata_with_other(dest_path, source_path, processor, data_offset_tree, rows_so_far)
)
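
Illustration, not part of this commit: the implicit lock mentioned above comes from Ray's actor model: an actor whose methods are plain def executes at most one method call at a time, whereas async def methods may interleave. A minimal sketch of that behavior, with invented names (SerializedWriter, append):

import ray


@ray.remote
class SerializedWriter:
    def __init__(self):
        self.log = []

    # A plain (non-async) actor method: Ray runs at most one call at a time on this
    # actor, so appends from many concurrent callers can never interleave.
    def append(self, item):
        self.log.append(item)
        return len(self.log)


if __name__ == "__main__":
    ray.init()
    writer = SerializedWriter.remote()
    refs = [writer.append.remote(i) for i in range(100)]  # many concurrent callers
    assert sorted(ray.get(refs)) == list(range(1, 101))  # every append landed, none lost
    ray.shutdown()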


async def _extend_cache_with_other_cache(
dest_path: str, source_path: str, processor: BatchProcessor, data_offset_tree: PyTree[int], row_offset
) -> int:
@@ -1218,13 +1263,58 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
# TODO: it'd be good if we just didn't expose the full data array (but only the used part)
data_size = source_array.data_size
data = source_array.data[0:data_size]
futures: list[ts.Future] = []

# To prevent OOM, copy in smaller batches
MAX_ELEMS = 1024 * 1024 * 1024
f = await _copy_in_batches(dest_array.data, data_offset, data, data_size, MAX_ELEMS)
if f is not None:
futures.append(f)
await _copy_in_batches(dest_array.data, data_offset, data, data_size, MAX_ELEMS)

futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree)

await asyncio.gather(*jax.tree.leaves(futures))
logger.info(f"Finished copying data from {source_path} to {dest_path}.")

return source_num_rows
except Exception as e:
logger.exception(f"Failed to copy data from {source_path} to {dest_path}: {e}")
raise
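
For reference, a minimal sketch (not from this commit) of the pattern used in _extend_cache_with_other_cache above: jax.tree.map applied to an async function produces a pytree whose leaves are un-awaited coroutines, and asyncio.gather over jax.tree.leaves runs them all concurrently. The dictionaries and the copy_leaf helper below are invented for illustration:

import asyncio

import jax
import numpy as np


async def copy_leaf(dest: np.ndarray, src: np.ndarray, offset: int) -> int:
    # Stand-in for an async tensorstore write at the given offset.
    await asyncio.sleep(0)
    dest[offset : offset + len(src)] = src
    return len(src)


async def main():
    dest = {"ids": np.zeros(8, dtype=int), "lens": np.zeros(8, dtype=int)}
    src = {"ids": np.array([1, 2, 3]), "lens": np.array([7, 8, 9])}
    offsets = {"ids": 2, "lens": 5}

    # jax.tree.map pairs up matching leaves; each call returns an un-awaited coroutine.
    coros = jax.tree.map(copy_leaf, dest, src, offsets)
    copied = await asyncio.gather(*jax.tree.leaves(coros))
    print(dest, copied)


asyncio.run(main())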


async def _copy_in_batches(dest_array, dest_offset, src_array, src_len, elems_per_batch):
"""
Copies the data from one array to another in batches.
"""
last_future: ts.Future | None = None
start = 0
out_start = dest_offset
while start < src_len:
if last_future is not None:
await last_future
async with ts.Transaction() as txn:
num_to_copy = min(elems_per_batch, src_len - start)
end = start + num_to_copy
out_end = out_start + num_to_copy

last_future = dest_array.with_transaction(txn)[out_start:out_end].write(src_array[start:end])
start += num_to_copy
out_start += num_to_copy

if last_future is not None:
await last_future
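
A rough usage sketch for _copy_in_batches, not from the repo; it assumes tensorstore's zarr driver over an in-memory kvstore is available. Each batch is written under its own transaction, and the previous batch's write future is awaited before the next transaction opens, which bounds how much data is in flight at once:

import asyncio

import tensorstore as ts

from levanter.store.cache import _copy_in_batches  # the helper defined above


async def demo():
    # Two tiny in-memory arrays standing in for the cache's data stores.
    src = await ts.open(
        {"driver": "zarr", "kvstore": "memory://"}, create=True, dtype=ts.int32, shape=[10]
    )
    dest = await ts.open(
        {"driver": "zarr", "kvstore": "memory://"}, create=True, dtype=ts.int32, shape=[20]
    )
    await src[...].write(list(range(10)))

    # Copy all 10 source elements into dest[5:15], at most 4 elements per transaction.
    await _copy_in_batches(dest, dest_offset=5, src_array=src, src_len=10, elems_per_batch=4)
    print(await dest.read())


asyncio.run(demo())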


async def _extend_cache_metadata_with_other(
dest_path: str, source_path: str, processor: BatchProcessor, data_offset_tree: PyTree[int], row_offset
) -> int:
"""Copies just the offsets and shapes (if present)"""
try:
logger.info(f"Copying metadata from {source_path} to {dest_path}.")
dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a")
source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True)

source_num_rows = await source.async_len()

async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int):
"""Copies **just the data array** from one shard to the permanent cache at a given offset."""

if source_array.shapes is not None:
source_shapes = source_array.shapes[0:source_num_rows]
@@ -1233,8 +1323,8 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
assert dest_shapes is not None
out_end = row_offset + source_num_rows
shape_future = dest_shapes.with_transaction(txn)[row_offset:out_end].write(source_shapes)
futures.append(shape_future)

# the 0th offset is the number of rows so we don't want to copy it into the destination
source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]]
source_offsets = _virtual_offset(source_offsets, data_offset)

Expand All @@ -1243,7 +1333,7 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
try:
async with ts.Transaction() as txn:
dest_offsets = dest_array.offsets
out_end = row_offset + 1 + source_num_rows
out_end = 1 + row_offset + source_num_rows
offset_future = dest_offsets.with_transaction(txn)[row_offset + 1 : out_end].write(
source_offsets
)
@@ -1257,43 +1347,20 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra
if delay > 120:
raise

futures.append(offset_future)

await asyncio.gather(*futures)
await offset_future
if source_array.shapes is not None:
await shape_future

futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree)

await asyncio.gather(*jax.tree.leaves(futures))
logger.info(f"Finished copying data from {source_path} to {dest_path}.")

logger.info(f"Finished copying metadata from {source_path} to {dest_path}.")
return source_num_rows
except Exception as e:
logger.exception(f"Failed to copy data from {source_path} to {dest_path}: {e}")
logger.exception(f"Failed to copy metadata from {source_path} to {dest_path}: {e}")
raise


async def _copy_in_batches(dest_array, dest_offset, src_array, src_len, elems_per_batch) -> ts.Future | None:
"""
Copies the data from one array to another in batches.
"""
last_future: ts.Future | None = None
start = 0
out_start = dest_offset
while start < src_len:
if last_future is not None:
await last_future
async with ts.Transaction() as txn:
num_to_copy = min(elems_per_batch, src_len - start)
end = start + num_to_copy
out_end = out_start + num_to_copy

last_future = dest_array.with_transaction(txn)[out_start:out_end].write(src_array[start:end])
start += num_to_copy
out_start += num_to_copy

return last_future


def _virtual_offset(base: ts.TensorStore, offset_amount):
"""
This function creates a new tensorstore that is a virtual offset of another tensorstore.
8 changes: 5 additions & 3 deletions tests/test_newdataset.py
@@ -37,14 +37,16 @@ async def test_list_async_dataset_appends_and_finalizes_correctly():

@pytest.mark.asyncio
async def test_permutation_dataset_is_at_least_sometimes_permuted():
ok = 0
for seed in range(10):
data = [1, 2, 3, 4]
dataset = ListAsyncDataset(data, is_complete=True)
permuted_dataset = PermutationDataset(dataset, jax.random.PRNGKey(seed))
if await permuted_dataset.get_batch([0, 1, 2, 3]) != [1, 2, 3, 4]:
return
batch = await permuted_dataset.get_batch([0, 1, 2, 3])
if batch != [1, 2, 3, 4]:
ok += 1

pytest.fail("PermutationDataset did not permute the data")
assert ok > 5, "Permutation dataset is not actually permuting"


@pytest.mark.asyncio
