Skip to content

Commit

Permalink
Add dask query-planning support (NVIDIA#139)
Browse files Browse the repository at this point in the history
* start adding dask-expr support

Signed-off-by: rjzamora <[email protected]>

* add query_planning_enabled util

Signed-off-by: rjzamora <[email protected]>

* add global keyword

Signed-off-by: rjzamora <[email protected]>

* Forgot to remove top level query-planning check

Signed-off-by: rjzamora <[email protected]>

* fix other shuffle-arg problems that don't 'work' with dask-expr

Signed-off-by: rjzamora <[email protected]>

* remove name arg usage for now

Signed-off-by: rjzamora <[email protected]>

* fix bugs

Signed-off-by: rjzamora <[email protected]>

---------

Signed-off-by: rjzamora <[email protected]>
  • Loading branch information
rjzamora authored and VibhuJawa committed Oct 1, 2024
1 parent eb4e997 commit 0f9fe20
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 56 deletions.
1 change: 0 additions & 1 deletion examples/slurm/start-slurm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ export CUDF_SPILL="1"
export RMM_SCHEDULER_POOL_SIZE="1GB"
export RMM_WORKER_POOL_SIZE="72GiB"
export LIBCUDF_CUFILE_POLICY=OFF
export DASK_DATAFRAME__QUERY_PLANNING=False


# =================================================================
Expand Down
19 changes: 0 additions & 19 deletions nemo_curator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import dask

# Disable query planning if possible
# https://github.com/NVIDIA/NeMo-Curator/issues/73
if dask.config.get("dataframe.query-planning") is True or "dask_expr" in sys.modules:
raise NotImplementedError(
"""
NeMo Curator does not support query planning yet.
Please disable query planning before importing
`dask.dataframe` or `dask_cudf`. This can be done via:
`export DASK_DATAFRAME__QUERY_PLANNING=False`, or
importing `dask.dataframe/dask_cudf` after importing
`nemo_curator`.
"""
)
else:
dask.config.set({"dataframe.query-planning": False})


from .modules import *
from .services import (
AsyncLLMClient,
Expand Down
18 changes: 18 additions & 0 deletions nemo_curator/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import dask
from packaging.version import parse as parseVersion

Expand All @@ -25,3 +27,19 @@
DASK_SHUFFLE_METHOD_ARG = _dask_version > parseVersion("2024.1.0")
DASK_P2P_ERROR = _dask_version < parseVersion("2023.10.0")
DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0")

# Query-planning check (and cache)
_DASK_QUERY_PLANNING_ENABLED = None


def query_planning_enabled():
global _DASK_QUERY_PLANNING_ENABLED

if _DASK_QUERY_PLANNING_ENABLED is None:
if _dask_version > parseVersion("2024.6.0"):
import dask.dataframe as dd

_DASK_QUERY_PLANNING_ENABLED = dd.DASK_EXPR_ENABLED
else:
_DASK_QUERY_PLANNING_ENABLED = "dask_expr" in sys.modules
return _DASK_QUERY_PLANNING_ENABLED
1 change: 0 additions & 1 deletion nemo_curator/datasets/doc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def from_pandas(
npartitions=npartitions,
chunksize=chunksize,
sort=sort,
name=name,
)
)

Expand Down
2 changes: 1 addition & 1 deletion nemo_curator/modules/exact_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _exact_dup_ids(self, df: dd.DataFrame):
Get the id's for text/documents that are exact duplicates
Parameters
----------
df: dask.dataframe.core.DataFrame
df: dask.dataframe.DataFrame
A dataframe with the following requirements:
* A column where each row is the text from one document
* A unique ID column for each document
Expand Down
12 changes: 6 additions & 6 deletions nemo_curator/modules/fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import pyarrow as pa
from cugraph import MultiGraph
from dask import dataframe as dd
from dask.dataframe.shuffle import shuffle as dd_shuffle
from dask.utils import M
from tqdm import tqdm

Expand Down Expand Up @@ -919,8 +918,7 @@ def map_buckets_with_anchors(
transform_divisions=False,
align_dataframes=False,
)
ddf_anchor_docs_with_bk = dd_shuffle(
ddf_anchor_docs_with_bk,
ddf_anchor_docs_with_bk = ddf_anchor_docs_with_bk.shuffle(
self.id_fields,
ignore_index=True,
shuffle_method=shuffle_type,
Expand Down Expand Up @@ -1496,8 +1494,7 @@ def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path):
align_dataframes=False,
)

ddf = dd_shuffle(
ddf,
ddf = ddf.shuffle(
[self.left_id, self.right_id],
ignore_index=True,
shuffle_method="tasks",
Expand Down Expand Up @@ -1545,7 +1542,10 @@ def _write_dedup_parsed_id(self):
unique_docs = ddf.map_partitions(
ConnectedComponents._get_unique_ids_per_partition, id_columns=id_columns
)
unique_docs = unique_docs.drop_duplicates(split_out=ddf.npartitions // 4)
unique_docs = unique_docs.drop_duplicates(
# Dask does not guard against split_out=0
split_out=max(ddf.npartitions // 4, 1)
)
unique_docs["uid"] = np.uint64(1)
unique_docs["uid"] = unique_docs["uid"].cumsum()
unique_docs["uid"] = unique_docs["uid"] - 1
Expand Down
14 changes: 11 additions & 3 deletions nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
import numpy as np
import pandas as pd
from dask.base import tokenize
from dask.dataframe.core import new_dd_object
from dask.dataframe.shuffle import partitioning_index
from dask.highlevelgraph import HighLevelGraph
from dask.utils import M

from nemo_curator._compat import DASK_SHUFFLE_CAST_DTYPE
from nemo_curator._compat import DASK_SHUFFLE_CAST_DTYPE, query_planning_enabled


def _split_part(part, nsplits):
Expand All @@ -36,6 +34,16 @@ def _split_part(part, nsplits):


def text_bytes_aware_merge(text_df, right_df, broadcast=True, *, on):

if query_planning_enabled():
raise NotImplementedError(
"The text_bytes_aware_merge function is not supported when "
"query-planning is enabled."
)

from dask.dataframe.core import new_dd_object
from dask.highlevelgraph import HighLevelGraph

if not isinstance(on, list):
on = [on]

Expand Down
15 changes: 12 additions & 3 deletions nemo_curator/utils/fuzzy_dedup_utils/output_map_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numba
import numpy as np

from nemo_curator._compat import DASK_SHUFFLE_METHOD_ARG
from nemo_curator._compat import DASK_SHUFFLE_METHOD_ARG, query_planning_enabled


def get_agg_text_bytes_df(
Expand All @@ -32,11 +32,20 @@ def get_agg_text_bytes_df(
"""
Groupby bucket and calculate total bytes for a bucket.
"""
shuffle_arg = "shuffle_method" if DASK_SHUFFLE_METHOD_ARG else "shuffle"
if query_planning_enabled():
# `shuffle_method: bool` doesn't really make sense
# when query-planning is enabled, because dask-expr
# will ALWAYS use a shuffle-based reduction when
# `split_out>1`
shuffle_arg = {}
else:
shuffle_arg = {
"shuffle_method" if DASK_SHUFFLE_METHOD_ARG else "shuffle": shuffle
}
agg_df = (
df[[agg_column, bytes_column]]
.groupby([agg_column])
.agg({bytes_column: "sum"}, split_out=n_partitions, **{shuffle_arg: shuffle})
.agg({bytes_column: "sum"}, split_out=n_partitions, **shuffle_arg)
)
agg_df = agg_df.reset_index(drop=False)
# Doing a per partition sort
Expand Down
28 changes: 26 additions & 2 deletions nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
import dask_cuda
import numpy as np
from dask import config
from dask.dataframe.shuffle import rearrange_by_column
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
from packaging.version import Version

from nemo_curator._compat import query_planning_enabled
from nemo_curator.utils.fuzzy_dedup_utils.output_map_utils import (
build_partition,
get_agg_text_bytes_df,
Expand Down Expand Up @@ -53,6 +52,10 @@ def rearange_by_column_direct(
):
# Execute a "direct" shuffle operation without staging
if config.get("explicit-comms", excomms_default):
from dask_cuda.explicit_comms.dataframe.shuffle import (
shuffle as explicit_comms_shuffle,
)

# Use explicit comms unless the user has
# disabled it with the dask config system,
# or we are using an older version of dask-cuda
Expand All @@ -62,7 +65,28 @@ def rearange_by_column_direct(
npartitions=npartitions,
ignore_index=ignore_index,
)

elif query_planning_enabled():
from dask_expr._collection import new_collection
from dask_expr._shuffle import RearrangeByColumn

# Use the internal dask-expr API
return new_collection(
RearrangeByColumn(
frame=df.expr,
partitioning_index=col,
npartitions_out=npartitions,
ignore_index=ignore_index,
method="tasks",
# Prevent staged shuffling by setting max_branch
# to the number of input partitions + 1
options={"max_branch": npartitions + 1},
)
)

else:
from dask.dataframe.shuffle import rearrange_by_column

return rearrange_by_column(
df,
col=col,
Expand Down
20 changes: 0 additions & 20 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import dask

# Disable query planning if possible
# https://github.com/NVIDIA/NeMo-Curator/issues/73
if dask.config.get("dataframe.query-planning") is True or "dask_expr" in sys.modules:
raise NotImplementedError(
"""
NeMo Curator does not support query planning yet.
Please disable query planning before importing
`dask.dataframe` or `dask_cudf`. This can be done via:
`export DASK_DATAFRAME__QUERY_PLANNING=False`, or
importing `dask.dataframe/dask_cudf` after importing
`nemo_curator`.
"""
)
else:
dask.config.set({"dataframe.query-planning": False})

0 comments on commit 0f9fe20

Please sign in to comment.