From d2737da6cc7e98ec8723731a7cc998aa143e6758 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Tue, 1 Nov 2022 16:35:06 +0000 Subject: [PATCH] Refactor TensorFlow operator into base and Triton runtime operator (#219) * Split out PredictTensorflow into two Operators Co-authored-by: Karl Higley Co-authored-by: Julio Perez <37191411+jperez999@users.noreply.github.com> * raise NotImplementedError in InferenceOperator export abstractmethod Co-authored-by: Julio Perez <37191411+jperez999@users.noreply.github.com> Co-authored-by: Karl Higley * Move Tensorflow Triton op to Triton Runtime module Co-authored-by: Julio Perez <37191411+jperez999@users.noreply.github.com> Co-authored-by: Karl Higley * Move triton runtime code to subdirectory Co-authored-by: Julio Perez <37191411+jperez999@users.noreply.github.com> Co-authored-by: Karl Higley * Use importlib.resources to find oprunner model * Use importlib.resources to find workflow_model.py * Move TritonOperator to operator module * Move `_tf_model_name` from base tensorflow op to triton op * Construct InferenceResponse correctly when an error occurs * Add docstrings to TritonOperator * Remove `exportable_backends` from `PredictTensorflow`, add docstrings Co-authored-by: Karl Higley Co-authored-by: Julio Perez <37191411+jperez999@users.noreply.github.com> --- merlin/systems/dag/ops/operator.py | 15 +- merlin/systems/dag/ops/tensorflow.py | 148 +++------------- merlin/systems/dag/runtimes/base_runtime.py | 1 + .../systems/dag/runtimes/triton/__init__.py | 20 +++ .../dag/runtimes/triton/ops/__init__.py | 15 ++ .../dag/runtimes/triton/ops/operator.py | 55 ++++++ .../dag/runtimes/triton/ops/tensorflow.py | 165 ++++++++++++++++++ .../runtimes/{triton.py => triton/runtime.py} | 51 ++++-- merlin/systems/triton/export.py | 31 ++-- merlin/systems/triton/models/__init__.py | 15 ++ .../systems/triton/models/executor_model.py | 3 +- .../systems/triton/models/oprunner_model.py | 3 +- tests/unit/systems/ops/tf/test_op.py | 11 +- 13 files changed, 375 insertions(+), 158 deletions(-) create mode 100644 merlin/systems/dag/runtimes/triton/__init__.py create mode 100644 merlin/systems/dag/runtimes/triton/ops/__init__.py create mode 100644 merlin/systems/dag/runtimes/triton/ops/operator.py create mode 100644 merlin/systems/dag/runtimes/triton/ops/tensorflow.py rename merlin/systems/dag/runtimes/{triton.py => triton/runtime.py} (91%) create mode 100644 merlin/systems/triton/models/__init__.py diff --git a/merlin/systems/dag/ops/operator.py b/merlin/systems/dag/ops/operator.py index 667fc3b5e..0b49a66cb 100644 --- a/merlin/systems/dag/ops/operator.py +++ b/merlin/systems/dag/ops/operator.py @@ -1,3 +1,4 @@ +import importlib.resources import json import os import pathlib @@ -113,6 +114,7 @@ def export( Node_configs: list A list of individual configs for each step (operator) in graph. 
""" + raise NotImplementedError def create_node(self, selector: ColumnSelector) -> InferenceNode: """_summary_ @@ -288,12 +290,13 @@ def export( os.makedirs(node_export_path, exist_ok=True) os.makedirs(os.path.join(node_export_path, str(version)), exist_ok=True) - copyfile( - os.path.join( - os.path.dirname(__file__), "..", "..", "triton", "models", "oprunner_model.py" - ), - os.path.join(node_export_path, str(version), "model.py"), - ) + with importlib.resources.path( + "merlin.systems.triton.models", "oprunner_model.py" + ) as oprunner_model: + copyfile( + oprunner_model, + os.path.join(node_export_path, str(version), "model.py"), + ) return config diff --git a/merlin/systems/dag/ops/tensorflow.py b/merlin/systems/dag/ops/tensorflow.py index 376c7e958..a0640cce9 100644 --- a/merlin/systems/dag/ops/tensorflow.py +++ b/merlin/systems/dag/ops/tensorflow.py @@ -16,28 +16,20 @@ import os import pathlib import tempfile -from shutil import copytree # this needs to be before any modules that import protobuf os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" import tensorflow as tf # noqa -import tritonclient.grpc.model_config_pb2 as model_config # noqa -from google.protobuf import text_format # noqa from merlin.core.protocols import Transformable # noqa from merlin.dag import ColumnSelector # noqa from merlin.schema import ColumnSchema, Schema # noqa -from merlin.systems.dag.ops import compute_dims # noqa -from merlin.systems.dag.ops.compat import pb_utils # noqa -from merlin.systems.dag.ops.operator import PipelineableInferenceOperator, add_model_param # noqa +from merlin.systems.dag.ops.operator import InferenceOperator # noqa -class PredictTensorflow(PipelineableInferenceOperator): - """ - This operator takes a tensorflow model and packages it correctly for tritonserver - to run, on the tensorflow backend. - """ +class PredictTensorflow(InferenceOperator): + """TensorFlow Model Prediction Operator.""" def __init__(self, model_or_path, custom_objects: dict = None, backend="tensorflow"): """ @@ -51,7 +43,6 @@ def __init__(self, model_or_path, custom_objects: dict = None, backend="tensorfl Any custom objects that need to be loaded with the model, by default None. """ super().__init__() - self._tf_model_name = None if model_or_path is not None: custom_objects = custom_objects or {} @@ -65,84 +56,38 @@ def __init__(self, model_or_path, custom_objects: dict = None, backend="tensorfl self.input_schema, self.output_schema = self._construct_schemas_from_model(self.model) - def __getstate__(self): - return {k: v for k, v in self.__dict__.items() if k != "model"} - - @property - def tf_model_name(self): - return self._tf_model_name + def __getstate__(self) -> dict: + """Return state of instance when pickled. - def set_tf_model_name(self, tf_model_name: str): + Returns + ------- + dict + Returns object state excluding model attribute. """ - Set the name of the Triton model to use + return {k: v for k, v in self.__dict__.items() if k != "model"} + + def transform( + self, col_selector: ColumnSelector, transformable: Transformable + ) -> Transformable: + """Run model inference. Returning predictions. 
        Parameters
         ----------
-        tf_model_name : str
-            Triton model directory name
-        """
-        self._tf_model_name = tf_model_name
+        col_selector : ColumnSelector
+            Unused ColumnSelector input
+        transformable : Transformable
+            Input features to model
 
-    def transform(self, col_selector: ColumnSelector, transformable: Transformable):
+        Returns
+        -------
+        Transformable
+            Model Predictions
+        """
         # TODO: Validate that the inputs match the schema
         # TODO: Should we coerce the dtypes to match the schema here?
-        input_tensors = []
-        for col_name in self.input_schema.column_schemas.keys():
-            input_tensors.append(pb_utils.Tensor(col_name, transformable[col_name]))
-
-        inference_request = pb_utils.InferenceRequest(
-            model_name=self.tf_model_name,
-            requested_output_names=self.output_schema.column_names,
-            inputs=input_tensors,
-        )
-        inference_response = inference_request.exec()
-
-        # TODO: Validate that the outputs match the schema
-        outputs_dict = {}
-        for out_col_name in self.output_schema.column_schemas.keys():
-            output_val = pb_utils.get_output_tensor_by_name(
-                inference_response, out_col_name
-            ).as_numpy()
-            outputs_dict[out_col_name] = output_val
-
-        return type(transformable)(outputs_dict)
-
-    @property
-    def exportable_backends(self):
-        return ["ensemble", "executor"]
-
-    def export(
-        self,
-        path: str,
-        input_schema: Schema,
-        output_schema: Schema,
-        params: dict = None,
-        node_id: int = None,
-        version: int = 1,
-        backend: str = "ensemble",
-    ):
-        """Create a directory inside supplied path based on our export name"""
-        # Export Triton TF back-end directory and config etc
-        export_name = self.__class__.__name__.lower()
-        node_name = f"{node_id}_{export_name}" if node_id is not None else export_name
-
-        node_export_path = pathlib.Path(path) / node_name
-        node_export_path.mkdir(exist_ok=True)
-
-        tf_model_path = pathlib.Path(node_export_path) / str(version) / "model.savedmodel"
-
-        if self.path:
-            copytree(
-                str(self.path),
-                tf_model_path,
-                dirs_exist_ok=True,
-            )
-        else:
-            self.model.save(tf_model_path, include_optimizer=False)
-
-        self.set_tf_model_name(node_name)
-        backend_model_config = self._export_model_config(node_name, node_export_path)
-        return backend_model_config
+        output = self.model(transformable)
+        # TODO: map output schema names to outputs produced by prediction
+        return type(transformable)({"output": output})
 
     @property
     def export_name(self):
@@ -180,45 +125,6 @@ def compute_output_schema(
         """
         return self.output_schema
 
-    def _export_model_config(self, name, output_path):
-        """Exports a TensorFlow model for serving with Triton
-
-        Parameters
-        ----------
-        model:
-            The tensorflow model that should be served
-        name:
-            The name of the triton model to export
-        output_path:
-            The path to write the exported model to
-        """
-        config = model_config.ModelConfig(
-            name=name, backend="tensorflow", platform="tensorflow_savedmodel"
-        )
-
-        config.parameters["TF_GRAPH_TAG"].string_value = "serve"
-        config.parameters["TF_SIGNATURE_DEF"].string_value = "serving_default"
-
-        for _, col_schema in self.input_schema.column_schemas.items():
-            add_model_param(
-                config.input,
-                model_config.ModelInput,
-                col_schema,
-                compute_dims(col_schema, self.scalar_shape),
-            )
-
-        for _, col_schema in self.output_schema.column_schemas.items():
-            add_model_param(
-                config.output,
-                model_config.ModelOutput,
-                col_schema,
-                compute_dims(col_schema, self.scalar_shape),
-            )
-
-        with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
-            text_format.PrintMessage(config, o)
-        return config
-
     def
_construct_schemas_from_model(self, model):
         signatures = getattr(model, "signatures", {}) or {}
         default_signature = signatures.get("serving_default")
diff --git a/merlin/systems/dag/runtimes/base_runtime.py b/merlin/systems/dag/runtimes/base_runtime.py
index 2c6995e2b..c68eeaadb 100644
--- a/merlin/systems/dag/runtimes/base_runtime.py
+++ b/merlin/systems/dag/runtimes/base_runtime.py
@@ -33,6 +33,7 @@ def __init__(self, executor=None):
             The Graph Executor to use for the transform, by default None
         """
         self.executor = executor or LocalExecutor()
+        self.op_table = {}
 
     def transform(self, graph: Graph, transformable: Transformable):
         """Run the graph with the input data.
diff --git a/merlin/systems/dag/runtimes/triton/__init__.py b/merlin/systems/dag/runtimes/triton/__init__.py
new file mode 100644
index 000000000..dd92e639d
--- /dev/null
+++ b/merlin/systems/dag/runtimes/triton/__init__.py
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# flake8: noqa
+from merlin.systems.dag.runtimes.triton.runtime import (  # noqa
+    TritonEnsembleRuntime,
+    TritonExecutorRuntime,
+)
diff --git a/merlin/systems/dag/runtimes/triton/ops/__init__.py b/merlin/systems/dag/runtimes/triton/ops/__init__.py
new file mode 100644
index 000000000..0b8ff56d3
--- /dev/null
+++ b/merlin/systems/dag/runtimes/triton/ops/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/merlin/systems/dag/runtimes/triton/ops/operator.py b/merlin/systems/dag/runtimes/triton/ops/operator.py
new file mode 100644
index 000000000..9038f3c51
--- /dev/null
+++ b/merlin/systems/dag/runtimes/triton/ops/operator.py
@@ -0,0 +1,55 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import List
+
+from merlin.systems.dag.ops.operator import InferenceOperator
+
+
+class TritonOperator:
+    """Base class for Triton operators."""
+
+    def __init__(self, base_op: InferenceOperator):
+        """Construct TritonOperator from a base operator.
+
+        Parameters
+        ----------
+        base_op : merlin.systems.dag.ops.operator.InferenceOperator
+            Base operator used to construct this Triton op.
+        """
+        self.op = base_op
+
+    @property
+    def export_name(self):
+        """
+        Provides a clear common English identifier for this operator.
+
+        Returns
+        -------
+        String
+            Name of the current class as spelled in module.
+        """
+        return self.__class__.__name__.lower()
+
+    @property
+    def exportable_backends(self) -> List[str]:
+        """Returns list of supported backends.
+
+        Returns
+        -------
+        List[str]
+            List of supported backends
+        """
+        return ["ensemble", "executor"]
diff --git a/merlin/systems/dag/runtimes/triton/ops/tensorflow.py b/merlin/systems/dag/runtimes/triton/ops/tensorflow.py
new file mode 100644
index 000000000..f57478a26
--- /dev/null
+++ b/merlin/systems/dag/runtimes/triton/ops/tensorflow.py
@@ -0,0 +1,165 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import pathlib
+from shutil import copytree
+
+# this needs to be before any modules that import protobuf
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+import tritonclient.grpc.model_config_pb2 as model_config  # noqa
+from google.protobuf import text_format  # noqa
+
+from merlin.core.protocols import Transformable  # noqa
+from merlin.dag import ColumnSelector  # noqa
+from merlin.schema import Schema  # noqa
+from merlin.systems.dag.ops import compute_dims  # noqa
+from merlin.systems.dag.ops.compat import pb_utils  # noqa
+from merlin.systems.dag.ops.operator import add_model_param  # noqa
+from merlin.systems.dag.runtimes.triton.ops.operator import TritonOperator  # noqa
+
+
+class PredictTensorflowTriton(TritonOperator):
+    """TensorFlow Model Prediction Operator for running inside Triton."""
+
+    def __init__(self, op):
+        super().__init__(op)
+
+        self.input_schema = op.input_schema
+        self.output_schema = op.output_schema
+        self.path = op.path
+        self.model = op.model
+        self.scalar_shape = op.scalar_shape
+
+        self._tf_model_name = None
+
+    def transform(self, col_selector: ColumnSelector, transformable: Transformable):
+        """Run transform of operator, calling the TensorFlow model with a Triton InferenceRequest.
+
+        Returns
+        -------
+        Transformable
+            TensorFlow Model Outputs
+        """
+        # TODO: Validate that the inputs match the schema
+        # TODO: Should we coerce the dtypes to match the schema here?
+ input_tensors = [] + for col_name in self.input_schema.column_schemas.keys(): + input_tensors.append(pb_utils.Tensor(col_name, transformable[col_name])) + + inference_request = pb_utils.InferenceRequest( + model_name=self.tf_model_name, + requested_output_names=self.output_schema.column_names, + inputs=input_tensors, + ) + inference_response = inference_request.exec() + + # TODO: Validate that the outputs match the schema + outputs_dict = {} + for out_col_name in self.output_schema.column_schemas.keys(): + output_val = pb_utils.get_output_tensor_by_name( + inference_response, out_col_name + ).as_numpy() + outputs_dict[out_col_name] = output_val + + return type(transformable)(outputs_dict) + + def export( + self, + path: str, + input_schema: Schema, + output_schema: Schema, + params: dict = None, + node_id: int = None, + version: int = 1, + backend: str = "ensemble", + ): + """Create a directory inside supplied path based on our export name""" + # Export Triton TF back-end directory and config etc + export_name = self.__class__.__name__.lower() + node_name = f"{node_id}_{export_name}" if node_id is not None else export_name + + node_export_path = pathlib.Path(path) / node_name + node_export_path.mkdir(exist_ok=True) + + tf_model_path = pathlib.Path(node_export_path) / str(version) / "model.savedmodel" + + if self.path: + copytree( + str(self.path), + tf_model_path, + dirs_exist_ok=True, + ) + else: + self.model.save(tf_model_path, include_optimizer=False) + + self.set_tf_model_name(node_name) + backend_model_config = self._export_model_config(node_name, node_export_path) + return backend_model_config + + def _export_model_config(self, name, output_path): + """Exports a TensorFlow model for serving with Triton + + Parameters + ---------- + model: + The tensorflow model that should be served + name: + The name of the triton model to export + output_path: + The path to write the exported model to + """ + config = model_config.ModelConfig( + name=name, backend="tensorflow", platform="tensorflow_savedmodel" + ) + + config.parameters["TF_GRAPH_TAG"].string_value = "serve" + config.parameters["TF_SIGNATURE_DEF"].string_value = "serving_default" + + for _, col_schema in self.input_schema.column_schemas.items(): + add_model_param( + config.input, + model_config.ModelInput, + col_schema, + compute_dims(col_schema, self.scalar_shape), + ) + + for _, col_schema in self.output_schema.column_schemas.items(): + add_model_param( + config.output, + model_config.ModelOutput, + col_schema, + compute_dims(col_schema, self.scalar_shape), + ) + + with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o: + text_format.PrintMessage(config, o) + return config + + @property + def tf_model_name(self): + return self._tf_model_name + + def set_tf_model_name(self, tf_model_name: str): + """ + Set the name of the Triton model to use + + Parameters + ---------- + tf_model_name : str + Triton model directory name + """ + self._tf_model_name = tf_model_name diff --git a/merlin/systems/dag/runtimes/triton.py b/merlin/systems/dag/runtimes/triton/runtime.py similarity index 91% rename from merlin/systems/dag/runtimes/triton.py rename to merlin/systems/dag/runtimes/triton/runtime.py index 7b98c2acf..2b47b9817 100644 --- a/merlin/systems/dag/runtimes/triton.py +++ b/merlin/systems/dag/runtimes/triton/runtime.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import importlib.resources import os import pathlib from shutil import copyfile @@ -31,11 +32,28 @@ from merlin.systems.dag.ops import compute_dims # noqa from merlin.systems.dag.ops.operator import add_model_param # noqa from merlin.systems.dag.runtimes import Runtime # noqa +from merlin.systems.dag.runtimes.triton.ops.tensorflow import PredictTensorflowTriton # noqa + +tensorflow = None +try: + import tensorflow +except ImportError: + ... + +TRITON_OP_TABLE = {} +if tensorflow: + from merlin.systems.dag.ops.tensorflow import PredictTensorflow + + TRITON_OP_TABLE[PredictTensorflow] = PredictTensorflowTriton class TritonEnsembleRuntime(Runtime): """Runtime for Triton. Runs each operator in DAG as a separate model in a Triton Ensemble.""" + def __init__(self): + super().__init__() + self.op_table = TRITON_OP_TABLE + def transform(self, graph: Graph, transformable: Transformable): raise NotImplementedError("Transform handled by Triton") @@ -67,7 +85,13 @@ def export( """ name = name or "ensemble_model" # Build node id lookup table + nodes = list(postorder_iter_nodes(ensemble.graph.output_node)) + + for node in nodes: + if type(node.op) in self.op_table: + node.op = self.op_table[type(node.op)](node.op) + node_id_table, num_nodes = _create_node_table(nodes, "ensemble") nodes = nodes or [] @@ -185,6 +209,10 @@ class TritonExecutorRuntime(Runtime): Triton models for nodes that use any non-python backends. """ + def __init__(self): + super().__init__() + self.op_table = TRITON_OP_TABLE + def export( self, ensemble, path: str, version: int = 1, name: str = None ) -> Tuple[model_config.ModelConfig, List[model_config.ModelConfig]]: @@ -214,6 +242,11 @@ def export( name = name or "executor_model" nodes = list(postorder_iter_nodes(ensemble.graph.output_node)) + + for node in nodes: + if type(node.op) in self.op_table: + node.op = self.op_table[type(node.op)](node.op) + node_id_table, _ = _create_node_table(nodes, "executor") node_configs = [] @@ -290,17 +323,13 @@ def _executor_model_export( os.makedirs(node_export_path, exist_ok=True) os.makedirs(os.path.join(node_export_path, str(version)), exist_ok=True) - copyfile( - os.path.join( - os.path.dirname(__file__), - "..", - "..", - "triton", - "models", - "executor_model.py", - ), - os.path.join(node_export_path, str(version), "model.py"), - ) + with importlib.resources.path( + "merlin.systems.triton.models", "executor_model.py" + ) as executor_model: + copyfile( + executor_model, + os.path.join(node_export_path, str(version), "model.py"), + ) ensemble.save(os.path.join(node_export_path, str(version), "ensemble")) diff --git a/merlin/systems/triton/export.py b/merlin/systems/triton/export.py index 4e96ca2b9..1c006a35d 100644 --- a/merlin/systems/triton/export.py +++ b/merlin/systems/triton/export.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import importlib.resources import json import os import warnings @@ -317,7 +317,7 @@ def _generate_ensemble_config(name, output_path, nvt_config, nn_config, name_ext config.ensemble_scheduling.step.append(nvt_step) config.ensemble_scheduling.step.append(tf_step) - with open(os.path.join(output_path, "config.pbtxt"), "w") as o: + with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o: text_format.PrintMessage(config, o) return config @@ -359,10 +359,13 @@ def generate_nvtabular_model( # copy the model file over. 
note that this isn't necessary with the c++ backend, but
     # does allow us to use the python backend by just changing the 'backend' parameter
-    copyfile(
-        os.path.join(os.path.dirname(__file__), "models", "workflow_model.py"),
-        os.path.join(output_path, str(version), "model.py"),
-    )
+    with importlib.resources.path(
+        "merlin.systems.triton.models", "workflow_model.py"
+    ) as workflow_model:
+        copyfile(
+            workflow_model,
+            os.path.join(output_path, str(version), "model.py"),
+        )
 
     return config
 
@@ -456,7 +459,7 @@ def _generate_nvtabular_config(
     else:
         _add_model_param(col_schema, model_config.ModelOutput, config.output)
 
-    with open(os.path.join(output_path, "config.pbtxt"), "w") as o:
+    with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
         text_format.PrintMessage(config, o)
     return config
 
@@ -514,7 +517,7 @@ def export_tensorflow_model(model, name, output_path, version=1):
         )
     )
 
-    with open(os.path.join(output_path, "config.pbtxt"), "w") as o:
+    with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
         text_format.PrintMessage(config, o)
     return config
 
@@ -576,13 +579,15 @@ def export_pytorch_model(
     )
 
     if sparse_max:
-        with open(os.path.join(output_path, str(version), "model_info.json"), "w") as o:
-            model_info = dict()
+        with open(
+            os.path.join(output_path, str(version), "model_info.json"), "w", encoding="utf-8"
+        ) as o:
+            model_info = {}
             model_info["sparse_max"] = sparse_max
             model_info["use_fix_dtypes"] = use_fix_dtypes
             json.dump(model_info, o)
 
-    with open(os.path.join(output_path, "config.pbtxt"), "w") as o:
+    with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
         text_format.PrintMessage(config, o)
     return config
 
@@ -604,7 +609,7 @@ def _generate_pytorch_config(model, name, output_path, max_batch_size=None):
         )
     )
 
-    with open(os.path.join(output_path, "config.pbtxt"), "w") as o:
+    with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
         text_format.PrintMessage(config, o)
     return config
 
@@ -675,7 +680,7 @@ def _generate_hugectr_config(name, output_path, hugectr_params, max_batch_size=N
     embeddingkey_long_type = model_config.ModelParameter(string_value=embeddingkey_long_type_val)
     config.parameters["embeddingkey_long_type"].CopyFrom(embeddingkey_long_type)
 
-    with open(os.path.join(output_path, "config.pbtxt"), "w") as o:
+    with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
         text_format.PrintMessage(config, o)
     return config
 
diff --git a/merlin/systems/triton/models/__init__.py b/merlin/systems/triton/models/__init__.py
new file mode 100644
index 000000000..0b8ff56d3
--- /dev/null
+++ b/merlin/systems/triton/models/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# diff --git a/merlin/systems/triton/models/executor_model.py b/merlin/systems/triton/models/executor_model.py index c29ded7d6..e08316ee4 100644 --- a/merlin/systems/triton/models/executor_model.py +++ b/merlin/systems/triton/models/executor_model.py @@ -119,7 +119,8 @@ def execute(self, requests): tb_string = repr(traceback.extract_tb(exc_traceback)) responses.append( pb_utils.InferenceResponse( - tensors=[], error=f"{exc_type}, {exc_value}, {tb_string}" + output_tensors=[], + error=pb_utils.TritonError(f"{exc_type}, {exc_value}, {tb_string}"), ) ) diff --git a/merlin/systems/triton/models/oprunner_model.py b/merlin/systems/triton/models/oprunner_model.py index ded2ac71e..eb6f06553 100644 --- a/merlin/systems/triton/models/oprunner_model.py +++ b/merlin/systems/triton/models/oprunner_model.py @@ -120,7 +120,8 @@ def execute(self, requests): tb_string = repr(traceback.extract_tb(exc_traceback)) responses.append( pb_utils.InferenceResponse( - tensors=[], error=f"{exc_type}, {exc_value}, {tb_string}" + output_tensors=[], + error=pb_utils.TritonError(f"{exc_type}, {exc_value}, {tb_string}"), ) ) diff --git a/tests/unit/systems/ops/tf/test_op.py b/tests/unit/systems/ops/tf/test_op.py index 2820ed392..9fc893bb3 100644 --- a/tests/unit/systems/ops/tf/test_op.py +++ b/tests/unit/systems/ops/tf/test_op.py @@ -31,6 +31,7 @@ from tritonclient.grpc import model_config_pb2 as model_config # noqa tf_op = pytest.importorskip("merlin.systems.dag.ops.tensorflow") +tf_triton_op = pytest.importorskip("merlin.systems.dag.runtimes.triton.ops.tensorflow") tf = pytest.importorskip("tensorflow") @@ -55,7 +56,8 @@ def test_tf_op_exports_own_config(tmpdir): output_schema = Schema([ColumnSchema("output", dtype=np.float32)]) # Triton - triton_op = tf_op.PredictTensorflow(model) + tf_model_op = tf_op.PredictTensorflow(model) + triton_op = tf_triton_op.PredictTensorflowTriton(tf_model_op) triton_op.export(tmpdir, input_schema, output_schema) # Export creates directory @@ -145,9 +147,8 @@ def test_tf_op_infers_schema_for_input_tuples(): metrics=[tf.metrics.SparseCategoricalAccuracy()], ) - # Triton - triton_op = tf_op.PredictTensorflow(model) - assert triton_op.input_schema == Schema( + op = tf_op.PredictTensorflow(model) + assert op.input_schema == Schema( [ ColumnSchema( name="input_1", @@ -166,6 +167,6 @@ def test_tf_op_infers_schema_for_input_tuples(): ), ] ) - assert triton_op.output_schema == Schema( + assert op.output_schema == Schema( [ColumnSchema("dot", dtype=np.float32, is_list=False, is_ragged=False)] )
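
After this refactor, graphs are built against the runtime-agnostic `PredictTensorflow`, and Triton-specific behavior is attached at export time: each Triton runtime looks up every node's op type in its `op_table` and wraps matches in the corresponding Triton operator (`PredictTensorflow` becomes `PredictTensorflowTriton`). A minimal sketch of the resulting workflow, assuming this patch is applied and TensorFlow is available; the saved-model path and export directory are illustrative, and the `Ensemble` wiring follows the patterns in this repo's tests:

import tensorflow as tf

from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.tensorflow import PredictTensorflow
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime

# Build the graph with the base operator only; nothing Triton-specific yet.
model = tf.keras.models.load_model("saved_model_dir")
tf_op = PredictTensorflow(model)

graph = tf_op.input_schema.column_names >> tf_op
ensemble = Ensemble(graph, tf_op.input_schema)

# While exporting, the runtime consults TRITON_OP_TABLE and swaps the node's
# op for PredictTensorflowTriton before writing the model repository.
runtime = TritonExecutorRuntime()
ensemble_config, node_configs = runtime.export(ensemble, "/tmp/model_repository")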
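
The `importlib.resources` changes are worth calling out: the old `os.path.dirname(__file__)` joins assumed a fixed on-disk layout, whereas `importlib.resources.path` resolves the packaged Triton model templates (`oprunner_model.py`, `workflow_model.py`, `executor_model.py`) through the import system. That lookup requires a real package, which is why the patch also adds `merlin/systems/triton/models/__init__.py`. The pattern in isolation, with an illustrative destination directory:

import importlib.resources
import os
from shutil import copyfile

export_dir = "/tmp/model_repository/0_predicttensorflowtriton/1"
os.makedirs(export_dir, exist_ok=True)

# Resolve a template that ships inside the merlin.systems.triton.models
# package. The yielded path is only guaranteed to exist inside the context
# manager, so copy the file before it exits.
with importlib.resources.path(
    "merlin.systems.triton.models", "oprunner_model.py"
) as template:
    copyfile(template, os.path.join(export_dir, "model.py"))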