Extract Python and Dask Executor classes from Workflow #1609

Merged: karlhigley merged 16 commits into NVIDIA-Merlin:main from karlhigley:refactor/decouple-dask on Aug 15, 2022.
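At a high level, the refactor separates per-partition transformation (`MerlinPythonExecutor.apply`) from Dask-level orchestration (`MerlinDaskExecutor.apply` and `fit`). As a rough usage sketch only, assuming the classes from the diff below are importable, and with `ddf`, `stat_nodes`, and `nodes` as placeholders for an existing Dask dataframe and operator-graph nodes (none of which appear on this page):

    # Hypothetical sketch: driving the extracted executors directly rather
    # than through nvt.Workflow. `ddf`, `stat_nodes`, `nodes` are placeholders.
    executor = MerlinDaskExecutor()
    executor.fit(ddf, stat_nodes)         # fit statistics operators (formerly inside Workflow.fit)
    out_ddf = executor.apply(ddf, nodes)  # transform every partition via MerlinPythonExecutor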
Commits (16 total; the diff below shows changes from 14 of them), all by karlhigley:
10c51e5  Extract `MerlinPythonExecutor` from `nvt.Workflow`
ead2f5b  Extract `MerlinDaskExecutor` from `nvt.Workflow`
a72c7be  Clean up `MerlinDaskExecutor`
5101c42  Clean up `MerlinPythonExecutor`
c3eddfb  Move `_clear_worker_cache` to `MerlinDaskExecutor`
32512d8  Move `ensure_optimize_dataframe_graph` into `MerlinDaskExecutor`
f117717  Clarify `Nodes` vs `Operators` in `Workflow.fit()`
bceae0b  Extract the Dask-specific part of `Workflow.fit` to `MerlinDaskExecutor`
9fd498b  Move Dask client into `MerlinDaskExecutor`
28ccd11  Inline `_get_stat_op_nodes` to improve clarity
4f3e941  Clean up `MerlinDaskExecutor.fit()`
64914a5  Merge branch 'main' into refactor/decouple-dask
7ca7c0d  Merge branch 'main' into refactor/decouple-dask
242fc36  Merge branch 'main' into refactor/decouple-dask
9df466c  Merge branch 'main' into refactor/decouple-dask
35f7c15  Merge branch 'main' into refactor/decouple-dask
@@ -0,0 +1,247 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging

import dask
import pandas as pd
from dask.core import flatten

from merlin.core.dispatch import concat_columns, is_list_dtype, list_val_dtype
from merlin.core.utils import (
    ensure_optimize_dataframe_graph,
    global_dask_client,
    set_client_deprecated,
)
from merlin.dag import ColumnSelector, Node
from merlin.io.worker import clean_worker_cache

LOG = logging.getLogger("nvtabular")


class MerlinPythonExecutor:
    def apply(self, df, nodes, output_dtypes=None, additional_columns=None, capture_dtypes=False):
        """
        Transforms a single dataframe (possibly a partition of a Dask Dataframe)
        by applying the operators from a collection of Nodes
        """
        output = None

        for node in nodes:
            node_input_cols = get_unique(node.input_schema.column_names)
            node_output_cols = get_unique(node.output_schema.column_names)
            addl_input_cols = set(node.dependency_columns.names)

            # Build input dataframe
            if node.parents_with_dependencies:
                # If there are parents, collect their outputs
                # to build the current node's input
                input_df = None
                seen_columns = None

                for parent in node.parents_with_dependencies:
                    parent_output_cols = get_unique(parent.output_schema.column_names)
                    parent_df = self.apply(df, [parent], capture_dtypes=capture_dtypes)
                    if input_df is None or not len(input_df):
                        input_df = parent_df[parent_output_cols]
                        seen_columns = set(parent_output_cols)
                    else:
                        new_columns = set(parent_output_cols) - seen_columns
                        input_df = concat_columns([input_df, parent_df[list(new_columns)]])
                        seen_columns.update(new_columns)

                # Check for additional input columns that aren't generated by parents
                # and fetch them from the root dataframe
                unseen_columns = set(node.input_schema.column_names) - seen_columns
                addl_input_cols = addl_input_cols.union(unseen_columns)

                # TODO: Find a better way to remove dupes
                addl_input_cols = addl_input_cols - set(input_df.columns)

                if addl_input_cols:
                    input_df = concat_columns([input_df, df[list(addl_input_cols)]])
            else:
                # If there are no parents, this is an input node,
                # so pull columns directly from root df
                input_df = df[node_input_cols + list(addl_input_cols)]

            # Compute the node's output
            if node.op:
                try:
                    # use input_columns to ensure correct grouping (subgroups)
                    selection = node.input_columns.resolve(node.input_schema)
                    output_df = node.op.transform(selection, input_df)

                    # Update or validate output_df dtypes
                    for col_name, output_col_schema in node.output_schema.column_schemas.items():
                        col_series = output_df[col_name]
                        col_dtype = col_series.dtype
                        is_list = is_list_dtype(col_series)

                        if is_list:
                            col_dtype = list_val_dtype(col_series)

                        output_df_schema = output_col_schema.with_dtype(
                            col_dtype, is_list=is_list, is_ragged=is_list
                        )

                        if capture_dtypes:
                            node.output_schema.column_schemas[col_name] = output_df_schema
                        elif len(output_df):
                            if output_col_schema.dtype != output_df_schema.dtype:
                                raise TypeError(
                                    f"Dtype discrepancy detected for column {col_name}: "
                                    f"operator {node.op.label} reported dtype "
                                    f"`{output_col_schema.dtype}` but returned dtype "
                                    f"`{output_df_schema.dtype}`."
                                )
                except Exception:
                    LOG.exception("Failed to transform operator %s", node.op)
                    raise
                if output_df is None:
                    raise RuntimeError(f"Operator {node.op} didn't return a value during transform")
            else:
                output_df = input_df

            # Combine output across node loop iterations

            # dask needs output to be in the same order defined as meta, reorder partitions here
            # this also selects columns (handling the case of removing columns from the output using
            # "-" overload)
            if output is None:
                output = output_df[node_output_cols]
            else:
                output = concat_columns([output, output_df[node_output_cols]])

        if additional_columns:
            output = concat_columns([output, df[get_unique(additional_columns)]])

        return output


class MerlinDaskExecutor:
    def __init__(self, client=None):
        self._executor = MerlinPythonExecutor()

        # Deprecate `client`
        if client is not None:
            set_client_deprecated(client, "Workflow")

    def __getstate__(self):
        # dask client objects aren't picklable - exclude from saved representation
        return {k: v for k, v in self.__dict__.items() if k != "client"}

    def apply(self, ddf, nodes, output_dtypes=None, additional_columns=None, capture_dtypes=False):
        """
        Transforms all partitions of a Dask Dataframe by applying the operators
        from a collection of Nodes
        """

        self._clear_worker_cache()

        # Check if we are only selecting columns (no transforms).
        # If so, we should perform column selection at the ddf level.
        # Otherwise, Dask will not push the column selection into the
        # IO function.
        if not nodes:
            return ddf[get_unique(additional_columns)] if additional_columns else ddf

        if isinstance(nodes, Node):
            nodes = [nodes]

        columns = list(flatten(wfn.output_columns.names for wfn in nodes))
        columns += additional_columns if additional_columns else []

        if isinstance(output_dtypes, dict) and isinstance(ddf._meta, pd.DataFrame):
            dtypes = output_dtypes
            output_dtypes = type(ddf._meta)({k: [] for k in columns})
            for column, dtype in dtypes.items():
                output_dtypes[column] = output_dtypes[column].astype(dtype)

        elif not output_dtypes:
            # TODO: constructing meta like this loses dtype information on the ddf
            # and sets it all to 'float64'. We should propagate dtype information along
            # with column names in the columngroup graph. This currently only
            # happens during intermediate 'fit' transforms, so as long as stat operators
            # don't require dtype information on the DDF this doesn't matter all that much
            output_dtypes = type(ddf._meta)({k: [] for k in columns})

        return ensure_optimize_dataframe_graph(
            ddf=ddf.map_partitions(
                self._executor.apply,
                nodes,
                additional_columns=additional_columns,
                capture_dtypes=capture_dtypes,
                meta=output_dtypes,
                enforce_metadata=False,
            )
        )

[Review comment on `fit`: 📄 missing `nodes` in Parameters docstring here]

    def fit(self, ddf, nodes):
        """Calculates statistics for a set of nodes on the input dataframe

        Parameters
        ----------
        ddf : dask.DataFrame
            The input dataframe to calculate statistics for. If there is a
            train/test split, this should be the training dataset only.
        nodes : list of Node
            The nodes whose statistics operators should be fit on the
            input dataframe.
        """
        stats = []
        for node in nodes:
            # Check for additional input columns that aren't generated by parents
            addl_input_cols = set()
            if node.parents:
                upstream_output_cols = sum(
                    [upstream.output_columns for upstream in node.parents_with_dependencies],
                    ColumnSelector(),
                )
                addl_input_cols = set(node.input_columns.names) - set(upstream_output_cols.names)

            # apply transforms necessary for the inputs to the current column group, ignoring
            # the transforms from the statop itself
            transformed_ddf = self.apply(
                ddf,
                node.parents_with_dependencies,
                additional_columns=addl_input_cols,
                capture_dtypes=True,
            )

            try:
                stats.append(node.op.fit(node.input_columns, transformed_ddf))
            except Exception:
                LOG.exception("Failed to fit operator %s", node.op)
                raise

        dask_client = global_dask_client()
        if dask_client:
            results = [r.result() for r in dask_client.compute(stats)]
        else:
            results = dask.compute(stats, scheduler="synchronous")[0]

        for computed_stats, node in zip(results, nodes):
            node.op.fit_finalize(computed_stats)

    def _clear_worker_cache(self):
        # Clear worker caches to be "safe"
        dask_client = global_dask_client()
        if dask_client:
            dask_client.run(clean_worker_cache)
        else:
            clean_worker_cache()


def get_unique(cols):
    # Need to preserve order in unique-column list
    return list({x: x for x in cols}.keys())
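A quick illustration of the order-preserving de-duplication that `get_unique` relies on (a minimal, self-contained check, not part of the diff): Python dicts preserve insertion order (3.7+), so the first occurrence of each column name keeps its position, unlike a plain `set()`. The column names below are made up for the example.

    # Minimal check of get_unique's dict-based de-duplication
    def get_unique(cols):
        return list({x: x for x in cols}.keys())

    print(get_unique(["user_id", "item_id", "user_id", "price"]))
    # ['user_id', 'item_id', 'price'] -- duplicates dropped, original order kept
    print(sorted(set(["user_id", "item_id", "user_id", "price"])))
    # ['item_id', 'price', 'user_id'] -- a set-based approach loses the order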
Review comment on `__getstate__`: ❓ I'm wondering where the `client` attribute is being set on the object (that this code is trying to exclude). I don't see a `self.client` in here. Could be something outside this module doing it, I suppose. Not suggesting we remove this now, since it was here before, and to reduce risk it makes sense to keep it.

Reply (karlhigley): I think this is leftover from a version before I realized I should use `set_client_deprecated`, so it is likely safe to remove. I know from the process of writing this that NVT tests will fail when saving a Workflow if there's a non-serializable client attribute on this object, so if it's problematic to remove, we'll find out quickly.
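To make the failure mode concrete, here is a minimal sketch of the pickling behavior discussed above, with `threading.Lock` standing in for a non-picklable Dask client (the class name and the `_executor` value are illustrative, not NVTabular's):

    import pickle
    import threading

    class ExecutorLike:
        def __init__(self):
            self.client = threading.Lock()      # stand-in for a non-picklable Dask client
            self._executor = "python-executor"  # ordinary picklable state

        def __getstate__(self):
            # Same pattern as MerlinDaskExecutor.__getstate__: exclude `client`
            # so the rest of the object can still be serialized
            return {k: v for k, v in self.__dict__.items() if k != "client"}

    # Without __getstate__, pickle.dumps would raise on the lock; with it,
    # the round trip succeeds and the excluded attribute is simply absent.
    restored = pickle.loads(pickle.dumps(ExecutorLike()))
    assert restored._executor == "python-executor"
    assert not hasattr(restored, "client")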