diff --git a/nvtabular/workflow/executor.py b/nvtabular/workflow/executor.py new file mode 100644 index 00000000000..36937f6d9ed --- /dev/null +++ b/nvtabular/workflow/executor.py @@ -0,0 +1,247 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +import dask +import pandas as pd +from dask.core import flatten + +from merlin.core.dispatch import concat_columns, is_list_dtype, list_val_dtype +from merlin.core.utils import ( + ensure_optimize_dataframe_graph, + global_dask_client, + set_client_deprecated, +) +from merlin.dag import ColumnSelector, Node +from merlin.io.worker import clean_worker_cache + +LOG = logging.getLogger("nvtabular") + + +class MerlinPythonExecutor: + def apply(self, df, nodes, output_dtypes=None, additional_columns=None, capture_dtypes=False): + """ + Transforms a single dataframe (possibly a partition of a Dask Dataframe) + by applying the operators from a collection of Nodes + """ + output = None + + for node in nodes: + node_input_cols = get_unique(node.input_schema.column_names) + node_output_cols = get_unique(node.output_schema.column_names) + addl_input_cols = set(node.dependency_columns.names) + + # Build input dataframe + if node.parents_with_dependencies: + # If there are parents, collect their outputs + # to build the current node's input + input_df = None + seen_columns = None + + for parent in node.parents_with_dependencies: + parent_output_cols = get_unique(parent.output_schema.column_names) + parent_df = self.apply(df, [parent], capture_dtypes=capture_dtypes) + if input_df is None or not len(input_df): + input_df = parent_df[parent_output_cols] + seen_columns = set(parent_output_cols) + else: + new_columns = set(parent_output_cols) - seen_columns + input_df = concat_columns([input_df, parent_df[list(new_columns)]]) + seen_columns.update(new_columns) + + # Check for additional input columns that aren't generated by parents + # and fetch them from the root dataframe + unseen_columns = set(node.input_schema.column_names) - seen_columns + addl_input_cols = addl_input_cols.union(unseen_columns) + + # TODO: Find a better way to remove dupes + addl_input_cols = addl_input_cols - set(input_df.columns) + + if addl_input_cols: + input_df = concat_columns([input_df, df[list(addl_input_cols)]]) + else: + # If there are no parents, this is an input node, + # so pull columns directly from root df + input_df = df[node_input_cols + list(addl_input_cols)] + + # Compute the node's output + if node.op: + try: + # use input_columns to ensure correct grouping (subgroups) + selection = node.input_columns.resolve(node.input_schema) + output_df = node.op.transform(selection, input_df) + + # Update or validate output_df dtypes + for col_name, output_col_schema in node.output_schema.column_schemas.items(): + col_series = output_df[col_name] + col_dtype = col_series.dtype + is_list = is_list_dtype(col_series) + + if is_list: + col_dtype = list_val_dtype(col_series) + + output_df_schema = output_col_schema.with_dtype( + 
+                            col_dtype, is_list=is_list, is_ragged=is_list
+                        )
+
+                        if capture_dtypes:
+                            node.output_schema.column_schemas[col_name] = output_df_schema
+                        elif len(output_df):
+                            if output_col_schema.dtype != output_df_schema.dtype:
+                                raise TypeError(
+                                    f"Dtype discrepancy detected for column {col_name}: "
+                                    f"operator {node.op.label} reported dtype "
+                                    f"`{output_col_schema.dtype}` but returned dtype "
+                                    f"`{output_df_schema.dtype}`."
+                                )
+                except Exception:
+                    LOG.exception("Failed to transform operator %s", node.op)
+                    raise
+                if output_df is None:
+                    raise RuntimeError(f"Operator {node.op} didn't return a value during transform")
+            else:
+                output_df = input_df
+
+            # Combine output across node loop iterations
+
+            # dask needs output to be in the same order defined as meta, reorder partitions here
+            # this also selects columns (handling the case of removing columns from the output using
+            # "-" overload)
+            if output is None:
+                output = output_df[node_output_cols]
+            else:
+                output = concat_columns([output, output_df[node_output_cols]])
+
+        if additional_columns:
+            output = concat_columns([output, df[get_unique(additional_columns)]])
+
+        return output
+
+
+class MerlinDaskExecutor:
+    def __init__(self, client=None):
+        self._executor = MerlinPythonExecutor()
+
+        # Deprecate `client`
+        if client is not None:
+            set_client_deprecated(client, "Workflow")
+
+    def __getstate__(self):
+        # dask client objects aren't picklable - exclude from saved representation
+        return {k: v for k, v in self.__dict__.items() if k != "client"}
+
+    def apply(self, ddf, nodes, output_dtypes=None, additional_columns=None, capture_dtypes=False):
+        """
+        Transforms all partitions of a Dask DataFrame by applying the operators
+        from a collection of Nodes
+        """
+
+        self._clear_worker_cache()
+
+        # Check if we are only selecting columns (no transforms).
+        # If so, we should perform column selection at the ddf level.
+        # Otherwise, Dask will not push the column selection into the
+        # IO function.
+        if not nodes:
+            return ddf[get_unique(additional_columns)] if additional_columns else ddf
+
+        if isinstance(nodes, Node):
+            nodes = [nodes]
+
+        columns = list(flatten(wfn.output_columns.names for wfn in nodes))
+        columns += additional_columns if additional_columns else []
+
+        if isinstance(output_dtypes, dict) and isinstance(ddf._meta, pd.DataFrame):
+            dtypes = output_dtypes
+            output_dtypes = type(ddf._meta)({k: [] for k in columns})
+            for column, dtype in dtypes.items():
+                output_dtypes[column] = output_dtypes[column].astype(dtype)
+
+        elif not output_dtypes:
+            # TODO: constructing meta like this loses dtype information on the ddf
+            # and sets it all to 'float64'. We should propagate dtype information along
+            # with column names in the columngroup graph. This currently only
+            # happens during intermediate 'fit' transforms, so as long as StatOperators
+            # don't require dtype information on the DDF this doesn't matter all that much
+            output_dtypes = type(ddf._meta)({k: [] for k in columns})
+
+        return ensure_optimize_dataframe_graph(
+            ddf=ddf.map_partitions(
+                self._executor.apply,
+                nodes,
+                additional_columns=additional_columns,
+                capture_dtypes=capture_dtypes,
+                meta=output_dtypes,
+                enforce_metadata=False,
+            )
+        )
+
+    def fit(self, ddf, nodes):
+        """Calculates statistics for a set of nodes on the input dataframe
+
+        Parameters
+        ----------
+        ddf: dask.DataFrame
+            The input dataframe to calculate statistics for. If there is a
+            train/test split this should be the training dataset only.
+ """ + stats = [] + for node in nodes: + # Check for additional input columns that aren't generated by parents + addl_input_cols = set() + if node.parents: + upstream_output_cols = sum( + [upstream.output_columns for upstream in node.parents_with_dependencies], + ColumnSelector(), + ) + addl_input_cols = set(node.input_columns.names) - set(upstream_output_cols.names) + + # apply transforms necessary for the inputs to the current column group, ignoring + # the transforms from the statop itself + transformed_ddf = self.apply( + ddf, + node.parents_with_dependencies, + additional_columns=addl_input_cols, + capture_dtypes=True, + ) + + try: + stats.append(node.op.fit(node.input_columns, transformed_ddf)) + except Exception: + LOG.exception("Failed to fit operator %s", node.op) + raise + + dask_client = global_dask_client() + if dask_client: + results = [r.result() for r in dask_client.compute(stats)] + else: + results = dask.compute(stats, scheduler="synchronous")[0] + + for computed_stats, node in zip(results, nodes): + node.op.fit_finalize(computed_stats) + + def _clear_worker_cache(self): + # Clear worker caches to be "safe" + dask_client = global_dask_client() + if dask_client: + dask_client.run(clean_worker_cache) + else: + clean_worker_cache() + + +def get_unique(cols): + # Need to preserve order in unique-column list + return list({x: x for x in cols}.keys()) diff --git a/nvtabular/workflow/workflow.py b/nvtabular/workflow/workflow.py index 145d6c536eb..704187e0144 100755 --- a/nvtabular/workflow/workflow.py +++ b/nvtabular/workflow/workflow.py @@ -27,22 +27,13 @@ import cudf except ImportError: cudf = None -import dask import pandas as pd -from dask.core import flatten - -import nvtabular -from merlin.core.dispatch import concat_columns, is_list_dtype, list_val_dtype -from merlin.core.utils import ( - ensure_optimize_dataframe_graph, - global_dask_client, - set_client_deprecated, -) + from merlin.dag import Graph from merlin.io import Dataset -from merlin.io.worker import clean_worker_cache from merlin.schema import Schema from nvtabular.ops import StatOperator +from nvtabular.workflow.executor import MerlinDaskExecutor from nvtabular.workflow.node import WorkflowNode LOG = logging.getLogger("nvtabular") @@ -80,10 +71,8 @@ class Workflow: """ def __init__(self, output_node: WorkflowNode, client: Optional["distributed.Client"] = None): - # Deprecate `client` - if client is not None: - set_client_deprecated(client, "Workflow") self.graph = Graph(output_node) + self.executor = MerlinDaskExecutor(client) def transform(self, dataset: Dataset) -> Dataset: """Transforms the dataset by applying the graph of operators to it. 
Requires the ``fit`` @@ -181,7 +170,6 @@ def fit(self, dataset: Dataset) -> "Workflow": Workflow This Workflow with statistics calculated on it """ - self._clear_worker_cache() self.clear_stats() if not self.graph.output_schema: @@ -192,68 +180,28 @@ def fit(self, dataset: Dataset) -> "Workflow": # Get a dictionary mapping all StatOperators we need to fit to a set of any dependent # StatOperators (having StatOperators that depend on the output of other StatOperators # means that will have multiple phases in the fit cycle here) - stat_ops = { - op: _get_stat_ops(op.parents_with_dependencies) - for op in _get_stat_ops([self.graph.output_node]) + stat_op_nodes = { + node: Graph.get_nodes_by_op_type(node.parents_with_dependencies, StatOperator) + for node in Graph.get_nodes_by_op_type([self.graph.output_node], StatOperator) } - while stat_ops: + while stat_op_nodes: # get all the StatOperators that we can currently call fit on (no outstanding # dependencies) - current_phase = [op for op, dependencies in stat_ops.items() if not dependencies] + current_phase = [ + node for node, dependencies in stat_op_nodes.items() if not dependencies + ] if not current_phase: # this shouldn't happen, but lets not infinite loop just in case raise RuntimeError("failed to find dependency-free StatOperator to fit") - stats, ops = [], [] - for workflow_node in current_phase: - # Check for additional input columns that aren't generated by parents - addl_input_cols = set() - if workflow_node.parents: - upstream_output_cols = sum( - [ - upstream.output_columns - for upstream in workflow_node.parents_with_dependencies - ], - nvtabular.ColumnSelector(), - ) - addl_input_cols = set(workflow_node.input_columns.names) - set( - upstream_output_cols.names - ) - - # apply transforms necessary for the inputs to the current column group, ignoring - # the transforms from the statop itself - transformed_ddf = ensure_optimize_dataframe_graph( - ddf=_transform_ddf( - ddf, - workflow_node.parents_with_dependencies, - additional_columns=addl_input_cols, - capture_dtypes=True, - ) - ) - - op = workflow_node.op - try: - stats.append(op.fit(workflow_node.input_columns, transformed_ddf)) - ops.append(op) - except Exception: - LOG.exception("Failed to fit operator %s", workflow_node.op) - raise - - dask_client = global_dask_client() - if dask_client: - results = [r.result() for r in dask_client.compute(stats)] - else: - results = dask.compute(stats, scheduler="synchronous")[0] - - for computed_stats, op in zip(results, ops): - op.fit_finalize(computed_stats) + self.executor.fit(ddf, current_phase) # Remove all the operators we processed in this phase, and remove # from the dependencies of other ops too - for stat_op in current_phase: - stat_ops.pop(stat_op) - for dependencies in stat_ops.values(): + for node in current_phase: + stat_op_nodes.pop(node) + for dependencies in stat_op_nodes.values(): dependencies.difference_update(current_phase) # This captures the output dtypes of operators like LambdaOp where @@ -287,14 +235,13 @@ def fit_transform(self, dataset: Dataset) -> Dataset: return self.transform(dataset) def _transform_impl(self, dataset: Dataset, capture_dtypes=False): - self._clear_worker_cache() - if not self.graph.output_schema: self.graph.construct_schema(dataset.schema) ddf = dataset.to_ddf(columns=self._input_columns()) + return Dataset( - _transform_ddf( + self.executor.apply( ddf, self.output_node, self.output_dtypes, capture_dtypes=capture_dtypes ), cpu=dataset.cpu, @@ -319,7 +266,7 @@ def save(self, path): # point all 
stat ops to store intermediate output (parquet etc) at the path # this lets us easily bundle - for stat in _get_stat_ops([self.output_node]): + for stat in Graph.get_nodes_by_op_type([self.output_node], StatOperator): stat.op.set_storage_path(path, copy=True) # generate a file of all versions used to generate this bundle @@ -394,15 +341,11 @@ def check_version(stored, current, name): # we might have been copied since saving, update all the stat ops # with the new path to their storage locations - for stat in _get_stat_ops([workflow.output_node]): + for stat in Graph.get_nodes_by_op_type([workflow.output_node], StatOperator): stat.op.set_storage_path(path, copy=False) return workflow - def __getstate__(self): - # dask client objects aren't picklable - exclude from saved representation - return {k: v for k, v in self.__dict__.items() if k != "client"} - def clear_stats(self): """Removes calculated statistics from each node in the workflow graph @@ -410,156 +353,5 @@ def clear_stats(self): -------- nvtabular.ops.stat_operator.StatOperator.clear """ - for stat in _get_stat_ops([self.output_node]): + for stat in Graph.get_nodes_by_op_type([self.graph.output_node], StatOperator): stat.op.clear() - - def _clear_worker_cache(self): - # Clear worker caches to be "safe" - dask_client = global_dask_client() - if dask_client: - dask_client.run(clean_worker_cache) - else: - clean_worker_cache() - - -def _transform_ddf(ddf, workflow_nodes, meta=None, additional_columns=None, capture_dtypes=False): - # Check if we are only selecting columns (no transforms). - # If so, we should perform column selection at the ddf level. - # Otherwise, Dask will not push the column selection into the - # IO function. - if not workflow_nodes: - return ddf[_get_unique(additional_columns)] if additional_columns else ddf - - if isinstance(workflow_nodes, WorkflowNode): - workflow_nodes = [workflow_nodes] - - columns = list(flatten(wfn.output_columns.names for wfn in workflow_nodes)) - columns += additional_columns if additional_columns else [] - - if isinstance(meta, dict) and isinstance(ddf._meta, pd.DataFrame): - dtypes = meta - meta = type(ddf._meta)({k: [] for k in columns}) - for column, dtype in dtypes.items(): - meta[column] = meta[column].astype(dtype) - - elif not meta: - # TODO: constructing meta like this loses dtype information on the ddf - # and sets it all to 'float64'. We should propagate dtype information along - # with column names in the columngroup graph. 
This currently only - # happesn during intermediate 'fit' transforms, so as long as statoperators - # don't require dtype information on the DDF this doesn't matter all that much - meta = type(ddf._meta)({k: [] for k in columns}) - - return ddf.map_partitions( - _transform_partition, - workflow_nodes, - additional_columns=additional_columns, - capture_dtypes=capture_dtypes, - meta=meta, - enforce_metadata=False, - ) - - -def _get_stat_ops(nodes): - return Graph.get_nodes_by_op_type(nodes, StatOperator) - - -def _get_unique(cols): - # Need to preserve order in unique-column list - return list({x: x for x in cols}.keys()) - - -def _transform_partition(root_df, workflow_nodes, additional_columns=None, capture_dtypes=False): - """Transforms a single partition by appyling all operators in a WorkflowNode""" - output = None - - for node in workflow_nodes: - node_input_cols = _get_unique(node.input_schema.column_names) - node_output_cols = _get_unique(node.output_schema.column_names) - addl_input_cols = set(node.dependency_columns.names) - - # Build input dataframe - if node.parents_with_dependencies: - # If there are parents, collect their outputs - # to build the current node's input - input_df = None - seen_columns = None - - for parent in node.parents_with_dependencies: - parent_output_cols = _get_unique(parent.output_schema.column_names) - parent_df = _transform_partition(root_df, [parent], capture_dtypes=capture_dtypes) - if input_df is None or not len(input_df): - input_df = parent_df[parent_output_cols] - seen_columns = set(parent_output_cols) - else: - new_columns = set(parent_output_cols) - seen_columns - input_df = concat_columns([input_df, parent_df[list(new_columns)]]) - seen_columns.update(new_columns) - - # Check for additional input columns that aren't generated by parents - # and fetch them from the root dataframe - unseen_columns = set(node.input_schema.column_names) - seen_columns - addl_input_cols = addl_input_cols.union(unseen_columns) - - # TODO: Find a better way to remove dupes - addl_input_cols = addl_input_cols - set(input_df.columns) - - if addl_input_cols: - input_df = concat_columns([input_df, root_df[list(addl_input_cols)]]) - else: - # If there are no parents, this is an input node, - # so pull columns directly from root df - input_df = root_df[node_input_cols + list(addl_input_cols)] - - # Compute the node's output - if node.op: - try: - # use input_columns to ensure correct grouping (subgroups) - selection = node.input_columns.resolve(node.input_schema) - output_df = node.op.transform(selection, input_df) - - # Update or validate output_df dtypes - for col_name, output_col_schema in node.output_schema.column_schemas.items(): - col_series = output_df[col_name] - col_dtype = col_series.dtype - is_list = is_list_dtype(col_series) - - if is_list: - col_dtype = list_val_dtype(col_series) - - output_df_schema = output_col_schema.with_dtype( - col_dtype, is_list=is_list, is_ragged=is_list - ) - - if capture_dtypes: - node.output_schema.column_schemas[col_name] = output_df_schema - elif len(output_df): - if output_col_schema.dtype != output_df_schema.dtype: - raise TypeError( - f"Dtype discrepancy detected for column {col_name}: " - f"operator {node.op.label} reported dtype " - f"`{output_col_schema.dtype}` but returned dtype " - f"`{output_df_schema.dtype}`." 
- ) - except Exception: - LOG.exception("Failed to transform operator %s", node.op) - raise - if output_df is None: - raise RuntimeError("Operator %s didn't return a value during transform" % node.op) - else: - output_df = input_df - - # Combine output across node loop iterations - - # dask needs output to be in the same order defined as meta, reorder partitions here - # this also selects columns (handling the case of removing columns from the output using - # "-" overload) - if output is None: - output = output_df[node_output_cols] - else: - output = concat_columns([output, output_df[node_output_cols]]) - - if additional_columns: - output = concat_columns([output, root_df[_get_unique(additional_columns)]]) - - return output
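
A minimal usage sketch (not part of this diff) for reviewers: the public `Workflow` API is unchanged by this refactor, since `Workflow.__init__` now builds a `MerlinDaskExecutor` internally and `fit()`/`transform()` delegate to it. The column names and file paths below are hypothetical; `Categorify` is used only because it is a `StatOperator`, so it exercises the new `MerlinDaskExecutor.fit()` path as well as `apply()`.

```python
import nvtabular as nvt
from merlin.io import Dataset

# Hypothetical input data and column names
dataset = Dataset("data/*.parquet")

# Categorify is a StatOperator, so Workflow.fit() routes through the new
# MerlinDaskExecutor.fit() -> op.fit() / op.fit_finalize() phases
cats = ["user_id", "item_id"] >> nvt.ops.Categorify()

workflow = nvt.Workflow(cats)              # constructs a MerlinDaskExecutor internally
workflow.fit(dataset)                      # statistics phase via executor.fit()
transformed = workflow.transform(dataset)  # executor.apply() -> ddf.map_partitions(MerlinPythonExecutor.apply, ...)
transformed.to_parquet("output/")          # hypothetical output path
```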