Raise TritonModelException if the Triton model has an error (#333)
* Raise TritonModelException if the PyTorch Triton model has an error

* Raise from exc in executor_model

* Raise RuntimeError after every InferenceRequest if there is an error

* Add test of PredictPyTorch with Triton

* Apply pre-commit formatting after merge

---------

Co-authored-by: Karl Higley <[email protected]>
3 people authored Apr 21, 2023
1 parent 20d1242 commit e94d2a9
Showing 6 changed files with 114 additions and 7 deletions.
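Taken together, the changes below implement a two-level error-propagation pattern: each Triton runtime op turns an inference-response error into an ordinary Python `RuntimeError`, and the executor model wraps any exception in `pb_utils.TritonModelException`, which the Triton Python backend reports back to the client. A condensed sketch of the pattern (the function names here are illustrative; the real methods appear in the diffs that follow):

import triton_python_backend_utils as pb_utils

def run_op(inference_request):
    # level 1, inside each runtime op's transform(): surface a response
    # error as a normal Python exception instead of returning bad data
    inference_response = inference_request.exec()
    if inference_response.has_error():
        raise RuntimeError(inference_response.error().message())
    return inference_response

def execute_ensemble(ensemble, inputs):
    # level 2, inside the executor model's execute(): wrap any failure so
    # the Python backend reports it to the Triton client
    try:
        return ensemble.transform(inputs)
    except Exception as exc:
        raise pb_utils.TritonModelException(str(exc)) from exc
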
4 changes: 4 additions & 0 deletions merlin/systems/dag/runtimes/triton/ops/fil.py
@@ -152,6 +152,10 @@ def transform(
             self.fil_model_name, inputs, input_schema, output_schema
         )
         inference_response = inference_request.exec()
+
+        if inference_response.has_error():
+            raise RuntimeError(str(inference_response.error().message()))
+
         return triton_response_to_tensor_table(inference_response, type(inputs), output_schema)


3 changes: 3 additions & 0 deletions merlin/systems/dag/runtimes/triton/ops/pytorch.py
@@ -83,6 +83,9 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
 
         inference_response = inference_request.exec()
 
+        if inference_response.has_error():
+            raise RuntimeError(str(inference_response.error().message()))
+
         return triton_response_to_tensor_table(
             inference_response, type(transformable), self.output_schema
         )

3 changes: 3 additions & 0 deletions merlin/systems/dag/runtimes/triton/ops/tensorflow.py
@@ -76,6 +76,9 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
         )
         inference_response = inference_request.exec()
 
+        if inference_response.has_error():
+            raise RuntimeError(inference_response.error().message())
+
         # TODO: Validate that the outputs match the schema
         return triton_response_to_tensor_table(
             inference_response, type(transformable), self.output_schema

7 changes: 1 addition & 6 deletions merlin/systems/dag/runtimes/triton/ops/workflow.py
@@ -20,7 +20,6 @@
 from shutil import copyfile
 
 import tritonclient.grpc.model_config_pb2 as model_config
-import tritonclient.utils
 from google.protobuf import text_format
 
 from merlin.core.protocols import Transformable
@@ -89,12 +88,8 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
 
         inference_response = inference_request.exec()
 
-        # check inference response for errors:
         if inference_response.has_error():
-            # Cannot raise inference response error because it is not derived from BaseException
-            raise tritonclient.utils.InferenceServerException(
-                str(inference_response.error().message())
-            )
+            raise RuntimeError(inference_response.error().message())
 
         response_table = triton_response_to_tensor_table(
             inference_response, type(transformable), self.output_schema

7 changes: 6 additions & 1 deletion merlin/systems/triton/models/executor_model.py
@@ -26,6 +26,8 @@
 import pathlib
 from pathlib import Path
 
+import triton_python_backend_utils as pb_utils
+
 from merlin.dag import postorder_iter_nodes
 from merlin.systems.dag import Ensemble
 from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
@@ -93,7 +95,10 @@ def execute(self, request):
             be the same as `requests`
         """
         inputs = triton_request_to_tensor_table(request, self.ensemble.input_schema)
-        outputs = self.ensemble.transform(inputs, runtime=TritonExecutorRuntime())
+        try:
+            outputs = self.ensemble.transform(inputs, runtime=TritonExecutorRuntime())
+        except Exception as exc:
+            raise pb_utils.TritonModelException(str(exc)) from exc
         return tensor_table_to_triton_response(outputs, self.ensemble.output_schema)
Expand Down
97 changes: 97 additions & 0 deletions tests/unit/systems/ops/torch/test_ensemble.py
@@ -0,0 +1,97 @@
+#
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import shutil
+
+import numpy as np
+import pandas as pd
+import pytest
+import tritonclient.utils
+
+from merlin.schema import ColumnSchema, Schema
+from merlin.systems.dag.ensemble import Ensemble
+from merlin.systems.dag.ops.pytorch import PredictPyTorch
+from merlin.systems.triton.utils import run_ensemble_on_tritonserver
+
+torch = pytest.importorskip("torch")
+
+TRITON_SERVER_PATH = shutil.which("tritonserver")
+
+
+@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
+def test_model_in_ensemble(tmpdir):
+    class MyModel(torch.nn.Module):
+        def forward(self, x):
+            v = torch.stack(list(x.values())).sum(axis=0)
+            return v
+
+    model = MyModel()
+
+    traced_model = torch.jit.trace(model, {"a": torch.tensor(1), "b": torch.tensor(2)}, strict=True)
+
+    model_input_schema = Schema(
+        [ColumnSchema("a", dtype="int64"), ColumnSchema("b", dtype="int64")]
+    )
+    model_output_schema = Schema([ColumnSchema("output", dtype="int64")])
+
+    model_node = model_input_schema.column_names >> PredictPyTorch(
+        traced_model, model_input_schema, model_output_schema
+    )
+
+    ensemble = Ensemble(model_node, model_input_schema)
+
+    ensemble_config, _ = ensemble.export(str(tmpdir))
+
+    df = pd.DataFrame({"a": [1], "b": [2]})
+
+    response = run_ensemble_on_tritonserver(
+        str(tmpdir), model_input_schema, df, ["output"], ensemble_config.name
+    )
+    np.testing.assert_array_equal(response["output"], np.array([3]))
+
+
+@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
+def test_model_error(tmpdir):
+    class MyModel(torch.nn.Module):
+        def forward(self, x):
+            v = torch.stack(list(x.values())).sum()
+            return v
+
+    model = MyModel()
+
+    traced_model = torch.jit.trace(model, {"a": torch.tensor(1), "b": torch.tensor(2)}, strict=True)
+
+    model_input_schema = Schema([ColumnSchema("a", dtype="int64")])
+    model_output_schema = Schema([ColumnSchema("output", dtype="int64")])
+
+    model_node = model_input_schema.column_names >> PredictPyTorch(
+        traced_model, model_input_schema, model_output_schema
+    )
+
+    ensemble = Ensemble(model_node, model_input_schema)
+
+    ensemble_config, _ = ensemble.export(str(tmpdir))
+
+    # run inference with missing input (that was present when model was compiled)
+    # we're expecting a KeyError at runtime.
+    df = pd.DataFrame({"a": [1]})
+
+    with pytest.raises(tritonclient.utils.InferenceServerException) as exc_info:
+        run_ensemble_on_tritonserver(
+            str(tmpdir), model_input_schema, df, ["output"], ensemble_config.name
+        )
+    assert "The following operation failed in the TorchScript interpreter" in str(exc_info.value)
+    assert "RuntimeError: KeyError: b" in str(exc_info.value)
