Refactor TensorFlow operator into base and Triton runtime operator (#219)

* Split out PredictTensorflow into two Operators

Co-authored-by: Karl Higley <[email protected]>
Co-authored-by: Julio Perez <[email protected]>

* raise NotImplementedError in InferenceOperator export abstractmethod

Co-authored-by: Julio Perez <[email protected]>
Co-authored-by: Karl Higley <[email protected]>

* Move Tensorflow Triton op to Triton Runtime module

Co-authored-by: Julio Perez <[email protected]>
Co-authored-by: Karl Higley <[email protected]>

* Move triton runtime code to subdirectory

Co-authored-by: Julio Perez <[email protected]>
Co-authored-by: Karl Higley <[email protected]>

* Use importlib.resources to find oprunner model

* Use importlib.resources to find workflow_model.py

* Move TritonOperator to operator module

* Move `_tf_model_name` from base tensorflow op to triton op

* Construct InferenceResponse correctly when an error occurs

* Add docstrings to TritonOperator

* Remove `exportable_backends` from `PredictTensorflow`, add docstrings

Co-authored-by: Karl Higley <[email protected]>
Co-authored-by: Julio Perez <[email protected]>
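
In outline, the refactor separates framework-level prediction from Triton-specific serving concerns. The sketch below illustrates the resulting shape under stated assumptions: it is not code from this commit, and `PredictTensorflowTriton` is a hypothetical name for the runtime-side wrapper.

    # Illustrative sketch of the operator split, not code from this commit.
    class PredictTensorflow:
        """Framework-level base operator: runs the model in-process."""

        def __init__(self, model):
            self.model = model

        def transform(self, col_selector, transformable):
            # No Triton dependency here; inference happens in this process.
            return type(transformable)({"output": self.model(transformable)})


    class PredictTensorflowTriton:
        """Triton runtime wrapper: adds serving/export concerns on top."""

        def __init__(self, base_op):
            self.op = base_op  # wrap, rather than subclass, the base operator

        def export(self, path, input_schema, output_schema, version=1):
            ...  # write the SavedModel and config.pbtxt for Triton's TF backend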
3 people authored Nov 1, 2022
1 parent dff7b02 commit d2737da
Showing 13 changed files with 375 additions and 158 deletions.
15 changes: 9 additions & 6 deletions merlin/systems/dag/ops/operator.py
@@ -1,3 +1,4 @@
+import importlib.resources
 import json
 import os
 import pathlib
@@ -113,6 +114,7 @@ def export(
         Node_configs: list
             A list of individual configs for each step (operator) in graph.
         """
+        raise NotImplementedError

     def create_node(self, selector: ColumnSelector) -> InferenceNode:
         """_summary_
@@ -288,12 +290,13 @@ def export(

         os.makedirs(node_export_path, exist_ok=True)
         os.makedirs(os.path.join(node_export_path, str(version)), exist_ok=True)
-        copyfile(
-            os.path.join(
-                os.path.dirname(__file__), "..", "..", "triton", "models", "oprunner_model.py"
-            ),
-            os.path.join(node_export_path, str(version), "model.py"),
-        )
+        with importlib.resources.path(
+            "merlin.systems.triton.models", "oprunner_model.py"
+        ) as oprunner_model:
+            copyfile(
+                oprunner_model,
+                os.path.join(node_export_path, str(version), "model.py"),
+            )

         return config
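
The change above swaps a brittle `__file__`-relative path for a package-resource lookup, which keeps working however the package is installed. A standalone sketch of the same standard-library pattern (the destination path below is illustrative):

    import importlib.resources
    from shutil import copyfile

    # importlib.resources.path() is a context manager that yields a concrete
    # filesystem path for a resource bundled inside an installed package,
    # extracting it to a temporary file if the package is zipped.
    with importlib.resources.path("merlin.systems.triton.models", "oprunner_model.py") as src:
        copyfile(src, "model.py")  # destination path is illustrative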
148 changes: 27 additions & 121 deletions merlin/systems/dag/ops/tensorflow.py
@@ -16,28 +16,20 @@
 import os
 import pathlib
 import tempfile
-from shutil import copytree

 # this needs to be before any modules that import protobuf
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

 import tensorflow as tf  # noqa
-import tritonclient.grpc.model_config_pb2 as model_config  # noqa
-from google.protobuf import text_format  # noqa

 from merlin.core.protocols import Transformable  # noqa
 from merlin.dag import ColumnSelector  # noqa
 from merlin.schema import ColumnSchema, Schema  # noqa
-from merlin.systems.dag.ops import compute_dims  # noqa
-from merlin.systems.dag.ops.compat import pb_utils  # noqa
-from merlin.systems.dag.ops.operator import PipelineableInferenceOperator, add_model_param  # noqa
+from merlin.systems.dag.ops.operator import InferenceOperator  # noqa


-class PredictTensorflow(PipelineableInferenceOperator):
-    """
-    This operator takes a tensorflow model and packages it correctly for tritonserver
-    to run, on the tensorflow backend.
-    """
+class PredictTensorflow(InferenceOperator):
+    """TensorFlow Model Prediction Operator."""

     def __init__(self, model_or_path, custom_objects: dict = None, backend="tensorflow"):
         """
@@ -51,7 +43,6 @@ def __init__(self, model_or_path, custom_objects: dict = None, backend="tensorflow"):
             Any custom objects that need to be loaded with the model, by default None.
         """
         super().__init__()
-        self._tf_model_name = None

         if model_or_path is not None:
             custom_objects = custom_objects or {}
@@ -65,84 +56,38 @@ def __init__(self, model_or_path, custom_objects: dict = None, backend="tensorflow"):

         self.input_schema, self.output_schema = self._construct_schemas_from_model(self.model)

-    def __getstate__(self):
-        return {k: v for k, v in self.__dict__.items() if k != "model"}
-
-    @property
-    def tf_model_name(self):
-        return self._tf_model_name
-
-    def set_tf_model_name(self, tf_model_name: str):
-        """
-        Set the name of the Triton model to use
-
-        Parameters
-        ----------
-        tf_model_name : str
-            Triton model directory name
-        """
-        self._tf_model_name = tf_model_name
-
-    def transform(self, col_selector: ColumnSelector, transformable: Transformable):
-        # TODO: Validate that the inputs match the schema
-        # TODO: Should we coerce the dtypes to match the schema here?
-        input_tensors = []
-        for col_name in self.input_schema.column_schemas.keys():
-            input_tensors.append(pb_utils.Tensor(col_name, transformable[col_name]))
-
-        inference_request = pb_utils.InferenceRequest(
-            model_name=self.tf_model_name,
-            requested_output_names=self.output_schema.column_names,
-            inputs=input_tensors,
-        )
-        inference_response = inference_request.exec()
-
-        # TODO: Validate that the outputs match the schema
-        outputs_dict = {}
-        for out_col_name in self.output_schema.column_schemas.keys():
-            output_val = pb_utils.get_output_tensor_by_name(
-                inference_response, out_col_name
-            ).as_numpy()
-            outputs_dict[out_col_name] = output_val
-
-        return type(transformable)(outputs_dict)
-
-    @property
-    def exportable_backends(self):
-        return ["ensemble", "executor"]
-
-    def export(
-        self,
-        path: str,
-        input_schema: Schema,
-        output_schema: Schema,
-        params: dict = None,
-        node_id: int = None,
-        version: int = 1,
-        backend: str = "ensemble",
-    ):
-        """Create a directory inside supplied path based on our export name"""
-        # Export Triton TF back-end directory and config etc
-        export_name = self.__class__.__name__.lower()
-        node_name = f"{node_id}_{export_name}" if node_id is not None else export_name
-
-        node_export_path = pathlib.Path(path) / node_name
-        node_export_path.mkdir(exist_ok=True)
-
-        tf_model_path = pathlib.Path(node_export_path) / str(version) / "model.savedmodel"
-
-        if self.path:
-            copytree(
-                str(self.path),
-                tf_model_path,
-                dirs_exist_ok=True,
-            )
-        else:
-            self.model.save(tf_model_path, include_optimizer=False)
-
-        self.set_tf_model_name(node_name)
-        backend_model_config = self._export_model_config(node_name, node_export_path)
-        return backend_model_config
+    def __getstate__(self) -> dict:
+        """Return state of instance when pickled.
+
+        Returns
+        -------
+        dict
+            Returns object state excluding model attribute.
+        """
+        return {k: v for k, v in self.__dict__.items() if k != "model"}
+
+    def transform(
+        self, col_selector: ColumnSelector, transformable: Transformable
+    ) -> Transformable:
+        """Run model inference, returning predictions.
+
+        Parameters
+        ----------
+        col_selector : ColumnSelector
+            Unused ColumnSelector input
+        transformable : Transformable
+            Input features to model
+
+        Returns
+        -------
+        Transformable
+            Model Predictions
+        """
+        # TODO: Validate that the inputs match the schema
+        # TODO: Should we coerce the dtypes to match the schema here?
+        output = self.model(transformable)
+        # TODO: map output schema names to outputs produced by prediction
+        return type(transformable)({"output": output})

     @property
     def export_name(self):
@@ -180,45 +125,6 @@ def compute_output_schema(
         """
         return self.output_schema

-    def _export_model_config(self, name, output_path):
-        """Exports a TensorFlow model for serving with Triton
-
-        Parameters
-        ----------
-        model:
-            The tensorflow model that should be served
-        name:
-            The name of the triton model to export
-        output_path:
-            The path to write the exported model to
-        """
-        config = model_config.ModelConfig(
-            name=name, backend="tensorflow", platform="tensorflow_savedmodel"
-        )
-
-        config.parameters["TF_GRAPH_TAG"].string_value = "serve"
-        config.parameters["TF_SIGNATURE_DEF"].string_value = "serving_default"
-
-        for _, col_schema in self.input_schema.column_schemas.items():
-            add_model_param(
-                config.input,
-                model_config.ModelInput,
-                col_schema,
-                compute_dims(col_schema, self.scalar_shape),
-            )
-
-        for _, col_schema in self.output_schema.column_schemas.items():
-            add_model_param(
-                config.output,
-                model_config.ModelOutput,
-                col_schema,
-                compute_dims(col_schema, self.scalar_shape),
-            )
-
-        with open(os.path.join(output_path, "config.pbtxt"), "w", encoding="utf-8") as o:
-            text_format.PrintMessage(config, o)
-        return config
-
     def _construct_schemas_from_model(self, model):
         signatures = getattr(model, "signatures", {}) or {}
         default_signature = signatures.get("serving_default")
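
With the Triton round-trip (a `pb_utils.InferenceRequest` against a named TF model) moved into the Triton runtime module, the base operator now just calls the model in-process. A hypothetical usage sketch of the slimmed-down operator; the model path and column names are assumptions:

    import numpy as np
    import tensorflow as tf

    from merlin.systems.dag.ops.tensorflow import PredictTensorflow

    # Load a Keras model and run predictions in-process, with no Triton server.
    model = tf.keras.models.load_model("saved_model_dir")  # path is illustrative
    op = PredictTensorflow(model)

    # Input column names must match the model's named inputs; these are made up.
    batch = {"user_id": np.array([[1], [2]]), "item_id": np.array([[10], [20]])}
    predictions = op.transform(None, batch)  # returns {"output": model(batch)}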
1 change: 1 addition & 0 deletions merlin/systems/dag/runtimes/base_runtime.py
@@ -33,6 +33,7 @@ def __init__(self, executor=None):
             The Graph Executor to use for the transform, by default None
         """
         self.executor = executor or LocalExecutor()
+        self.op_table = {}

     def transform(self, graph: Graph, transformable: Transformable):
         """Run the graph with the input data.
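
The new `op_table` dict gives each runtime a place to register runtime-specific replacements for base operators. How entries are registered and applied is not shown in this hunk; the sketch below is one plausible shape, with `PredictTensorflowTriton`, `MyTritonRuntime`, and the `graph.nodes` iteration all assumed:

    from merlin.systems.dag.ops.tensorflow import PredictTensorflow
    from merlin.systems.dag.runtimes.base_runtime import Runtime  # class name assumed


    class PredictTensorflowTriton:  # hypothetical Triton-side wrapper
        def __init__(self, base_op):
            self.op = base_op


    class MyTritonRuntime(Runtime):
        """Sketch of op_table-based dispatch; not the actual runtime code."""

        def __init__(self, executor=None):
            super().__init__(executor)
            # Map base operator classes to their runtime-specific counterparts.
            self.op_table[PredictTensorflow] = PredictTensorflowTriton

        def _replace_ops(self, graph):
            # Before export/execution, swap any base op that has a replacement.
            for node in graph.nodes:  # node iteration API is assumed
                runtime_cls = self.op_table.get(type(node.op))
                if runtime_cls is not None:
                    node.op = runtime_cls(node.op)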
20 changes: 20 additions & 0 deletions merlin/systems/dag/runtimes/triton/__init__.py
@@ -0,0 +1,20 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# flake8: noqa
from merlin.systems.dag.runtimes.triton.runtime import (  # noqa
    TritonEnsembleRuntime,
    TritonExecutorRuntime,
)
15 changes: 15 additions & 0 deletions merlin/systems/dag/runtimes/triton/ops/__init__.py
@@ -0,0 +1,15 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
55 changes: 55 additions & 0 deletions merlin/systems/dag/runtimes/triton/ops/operator.py
@@ -0,0 +1,55 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List

from merlin.systems.dag.ops.operator import InferenceOperator


class TritonOperator:
    """Base class for Triton operators."""

    def __init__(self, base_op: InferenceOperator):
        """Construct TritonOperator from a base operator.

        Parameters
        ----------
        base_op : merlin.systems.dag.ops.operator.InferenceOperator
            Base operator used to construct this Triton op.
        """
        self.op = base_op

    @property
    def export_name(self):
        """
        Provides a clear, common English identifier for this operator.

        Returns
        -------
        String
            Name of the current class as spelled in module.
        """
        return self.__class__.__name__.lower()

    @property
    def exportable_backends(self) -> List[str]:
        """Returns list of supported backends.

        Returns
        -------
        List[str]
            List of supported backends
        """
        return ["ensemble", "executor"]
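
Concrete runtime operators are expected to wrap a base op and layer Triton-specific export on top. A minimal hypothetical subclass (the export signature mirrors the base `InferenceOperator.export`, and the artifact layout shown is the standard Triton model-repository convention):

    class PredictTensorflowTriton(TritonOperator):  # hypothetical subclass
        def export(self, path, input_schema, output_schema, params=None, node_id=None, version=1):
            name = f"{node_id}_{self.export_name}" if node_id is not None else self.export_name
            # Write <path>/<name>/<version>/model.savedmodel plus
            # <path>/<name>/config.pbtxt targeting Triton's "tensorflow" backend.
            ...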