diff --git a/dask_cuda/benchmarks/custom/parquet.py b/dask_cuda/benchmarks/custom/parquet.py
index fa5075bc..47bf69ef 100644
--- a/dask_cuda/benchmarks/custom/parquet.py
+++ b/dask_cuda/benchmarks/custom/parquet.py
@@ -67,7 +67,12 @@ def tables_to_frame(tables):
     )


-def read_parquet_fragments(fragments, columns=None, filters=None):
+def read_parquet_fragments(
+    fragments,
+    columns=None,
+    filters=None,
+    fragment_parallelism=None,
+):

     kwargs = {"columns": columns, "filters": filters}
     if not isinstance(fragments, list):
@@ -89,7 +94,11 @@ def read_parquet_fragments(fragments, columns=None, filters=None):
             return dask.threaded.get(dsk, chunk_name)

     if not hasattr(worker, "_rapids_executor"):
-        num_threads = len(os.sched_getaffinity(0))
+        fragment_parallelism = fragment_parallelism or 8
+        num_threads = min(
+            fragment_parallelism,
+            len(os.sched_getaffinity(0)),
+        )
         worker._rapids_executor = ThreadPoolExecutor(num_threads)
     with dask.config.set(pool=worker._rapids_executor):
         return dask.threaded.get(dsk, chunk_name)
@@ -129,7 +138,13 @@ def aggregate_fragments(fragments, blocksize):
     return [fragments[i : i + stride] for i in range(0, len(fragments), stride)]


-def read_parquet(urlpath, columns=None, filters=None, blocksize="256MB", **kwargs):
+def read_parquet(
+    urlpath,
+    columns=None,
+    filters=None,
+    blocksize="256MB",
+    fragment_parallelism=None,
+):

     # Use pyarrow dataset API to get fragments and meta
     ds = dataset.dataset(urlpath, format="parquet")
@@ -148,6 +163,7 @@ def read_parquet(urlpath, columns=None, filters=None, blocksize="256MB", **kwarg
         fragments,
         columns=columns,
         filters=filters,
+        fragment_parallelism=fragment_parallelism,
         meta=meta,
         enforce_metadata=False,
     )
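
For reviewers, a minimal standalone sketch of the thread-count logic this patch introduces: the default of 8 and the CPU-affinity cap are taken from the hunk above, while the helper name `_executor_threads` is hypothetical, used only for illustration. Note that `os.sched_getaffinity` is Linux-only.

import os
from concurrent.futures import ThreadPoolExecutor

def _executor_threads(fragment_parallelism=None):
    # Default to 8 fragment-reading threads, but never exceed the
    # number of CPUs this process is actually allowed to run on.
    # NOTE: os.sched_getaffinity is only available on Linux.
    fragment_parallelism = fragment_parallelism or 8
    return min(fragment_parallelism, len(os.sched_getaffinity(0)))

# On a worker pinned to 4 cores, both pools get 4 threads; on a
# 64-core machine the default pool still tops out at 8 threads.
default_pool = ThreadPoolExecutor(_executor_threads())
wide_pool = ThreadPoolExecutor(_executor_threads(fragment_parallelism=16))

One side effect worth noting: `fragment_parallelism or 8` also maps an explicit `fragment_parallelism=0` to the default of 8, since `or` treats 0 as falsy.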