diff --git a/merlin/systems/dag/ops/faiss.py b/merlin/systems/dag/ops/faiss.py
index 1e7630645..dcd251431 100644
--- a/merlin/systems/dag/ops/faiss.py
+++ b/merlin/systems/dag/ops/faiss.py
@@ -202,11 +202,6 @@ def compute_input_schema(
         input_schema = super().compute_input_schema(
             root_schema, parents_schema, deps_schema, selector
         )
-        if len(input_schema.column_schemas) > 1:
-            raise ValueError(
-                "More than one input has been detected for this node,"
-                / f"inputs received: {input_schema.column_names}"
-            )
         return input_schema
 
     def compute_output_schema(
@@ -237,6 +232,23 @@ def compute_output_schema(
             ]
         )
 
+    def validate_schemas(
+        self, parents_schema, deps_schema, input_schema, output_schema, strict_dtypes=False
+    ):
+        """
+        Validate that this node receives at most one input column.
+
+        Raises
+        ------
+        ValueError
+            If ``input_schema`` has more than one column.
+        """
+        if len(input_schema.column_schemas) > 1:
+            raise ValueError(
+                "More than one input has been detected for this node, "
+                f"inputs received: {input_schema.column_names}"
+            )
+
 
 def setup_faiss(item_vector, output_path: str):
     """
diff --git a/merlin/systems/dag/ops/session_filter.py b/merlin/systems/dag/ops/session_filter.py
index 2b1a85c4e..914ac7af4 100644
--- a/merlin/systems/dag/ops/session_filter.py
+++ b/merlin/systems/dag/ops/session_filter.py
@@ -110,24 +110,6 @@ def compute_input_schema(
             root_schema, parents_schema, deps_schema, selector
         )
 
-        if len(parents_schema.column_schemas) > 1:
-            raise ValueError(
-                "More than one input has been detected for this node,"
-                / f"inputs received: {input_schema.column_names}"
-            )
-        if len(deps_schema.column_schemas) > 1:
-            raise ValueError(
-                "More than one dependency input has been detected"
-                / f"for this node, inputs received: {input_schema.column_names}"
-            )
-
-        # 1 for deps and 1 for parents
-        if len(input_schema.column_schemas) > 2:
-            raise ValueError(
-                "More than one input has been detected for this node,"
-                / f"inputs received: {input_schema.column_names}"
-            )
-
         self._input_col = parents_schema.column_names[0]
         self._filter_out_col = deps_schema.column_names[0]
 
@@ -157,6 +139,36 @@ def compute_output_schema(
         """
         return Schema([ColumnSchema("filtered_ids", dtype=np.int32, is_list=False)])
 
+    def validate_schemas(
+        self, parents_schema, deps_schema, input_schema, output_schema, strict_dtypes=False
+    ):
+        """
+        Validate that this node has at most one parent and one dependency column.
+
+        Raises
+        ------
+        ValueError
+            If ``parents_schema`` or ``deps_schema`` has more than one column,
+            or ``input_schema`` has more than two columns.
+        """
+        if len(parents_schema.column_schemas) > 1:
+            raise ValueError(
+                "More than one input has been detected for this node, "
+                f"inputs received: {input_schema.column_names}"
+            )
+        if len(deps_schema.column_schemas) > 1:
+            raise ValueError(
+                "More than one dependency input has been detected "
+                f"for this node, inputs received: {input_schema.column_names}"
+            )
+
+        # 1 for deps and 1 for parents
+        if len(input_schema.column_schemas) > 2:
+            raise ValueError(
+                "More than one input has been detected for this node, "
+                f"inputs received: {input_schema.column_names}"
+            )
+
     def transform(self, df: InferenceDataFrame):
         """
         Transform input dataframe to output dataframe using function logic.