Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cpp code path for Categorify ops #389

Merged
merged 9 commits into from
Jun 11, 2024
17 changes: 11 additions & 6 deletions merlin/systems/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,20 @@ def __init__(self, workflow, output_dtypes, model_config, model_device):
)

# recurse over all column groups, initializing operators for inference pipeline.
# (disabled for now while we sort out whether and how we want to use C++ implementations
# of NVTabular operators for performance optimization)
# self._initialize_ops(self.workflow.output_node)
# (disabled everything other than Categorify for now while we sort out whether
# and how we want to use C++ implementations of NVTabular operators for
# performance optimization)
self._initialize_ops(self.workflow.output_node, restrict=["Categorify"])

def _initialize_ops(self, workflow_node, visited=None, restrict=None):
restrict = restrict or []

def _initialize_ops(self, workflow_node, visited=None):
if visited is None:
visited = set()

if workflow_node.op and hasattr(workflow_node.op, "inference_initialize"):
if workflow_node.op and hasattr(workflow_node.op, "inference_initialize") and (
not restrict or workflow_node.op.label in restrict
):
inference_op = workflow_node.op.inference_initialize(
workflow_node.selector, self.model_config
)
Expand All @@ -96,7 +101,7 @@ def _initialize_ops(self, workflow_node, visited=None):
for parent in workflow_node.parents_with_dependencies:
if parent not in visited:
visited.add(parent)
self._initialize_ops(parent, visited)
self._initialize_ops(parent, visited=visited, restrict=restrict)

def run_workflow(self, input_tensors):
transformable = TensorTable(input_tensors).to_df()
Expand Down
Loading