Refactor tensor representation adjustments to align with a schema (#332)
* Rework `TensorTable`/Triton request/response methods to align w/ schemas

* Adjust `WorkflowRunner` to use `align_with_schema`

* Clean up and add docstrings
karlhigley authored Apr 19, 2023
1 parent d652251 commit 20d1242
Showing 7 changed files with 119 additions and 108 deletions.
6 changes: 4 additions & 2 deletions merlin/systems/dag/runtimes/triton/ops/fil.py
@@ -145,12 +145,14 @@ def transform(
)

inputs = TensorTable({"input__0": input0})
input_schema = Schema(["input__0"])
output_schema = Schema(["output__0"])

inference_request = tensor_table_to_triton_request(
self.fil_model_name, inputs, ["input__0"], ["output__0"]
self.fil_model_name, inputs, input_schema, output_schema
)
inference_response = inference_request.exec()
return triton_response_to_tensor_table(inference_response, type(inputs), ["output__0"])
return triton_response_to_tensor_table(inference_response, type(inputs), output_schema)


class FILTriton(TritonOperator):
6 changes: 3 additions & 3 deletions merlin/systems/dag/runtimes/triton/ops/pytorch.py
@@ -77,14 +77,14 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
inference_request = tensor_table_to_triton_request(
self.torch_model_name,
transformable,
self.input_schema.column_names,
self.output_schema.column_names,
self.input_schema,
self.output_schema,
)

inference_response = inference_request.exec()

return triton_response_to_tensor_table(
inference_response, type(transformable), self.output_schema.column_names
inference_response, type(transformable), self.output_schema
)

def compute_input_schema(
6 changes: 3 additions & 3 deletions merlin/systems/dag/runtimes/triton/ops/tensorflow.py
@@ -71,14 +71,14 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
inference_request = tensor_table_to_triton_request(
self.tf_model_name,
transformable,
self.input_schema.column_names,
self.output_schema.column_names,
self.input_schema,
self.output_schema,
)
inference_response = inference_request.exec()

# TODO: Validate that the outputs match the schema
return triton_response_to_tensor_table(
inference_response, type(transformable), self.output_schema.column_names
inference_response, type(transformable), self.output_schema
)

def export(
14 changes: 3 additions & 11 deletions merlin/systems/dag/runtimes/triton/ops/workflow.py
@@ -80,19 +80,11 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
TensorTable
Returns a transformed dataframe for this operator"""

output_names = []
for col in self.output_schema:
if col.is_ragged:
output_names.append(f"{col.name}__values")
output_names.append(f"{col.name}__offsets")
else:
output_names.append(col.name)

inference_request = tensor_table_to_triton_request(
self._nvt_model_name,
transformable,
self.input_schema.column_names,
output_names,
self.input_schema,
self.output_schema,
)

inference_response = inference_request.exec()
@@ -105,7 +97,7 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
)

response_table = triton_response_to_tensor_table(
inference_response, type(transformable), output_names
inference_response, type(transformable), self.output_schema
)

return response_table
151 changes: 100 additions & 51 deletions merlin/systems/triton/conversions.py
@@ -25,6 +25,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import itertools
from typing import Any, Dict, List

import numpy as np
import pandas as pd
@@ -33,11 +34,93 @@
from merlin.core.compat import cupy as cp
from merlin.core.dispatch import build_cudf_list_column, is_list_dtype
from merlin.dag import Supports
from merlin.schema import Schema
from merlin.systems.dag.ops.compat import pb_utils
from merlin.table import TensorTable


def triton_request_to_tensor_table(request, column_names):
def tensor_names(schema: Schema) -> List[str]:
"""
Compute the expected tensor names from a Merlin schema
This takes the columns from a schema, checks whether the columns are ragged or not,
and translates ragged columns to two separate tensor names for the values/offsets
representation.
Parameters
----------
schema : Schema
Schema to compute tensor names for
Returns
-------
List[str]
A list of the tensors implied by the schema
"""
tensor_names = []
for col_name, col_schema in schema.column_schemas.items():
if col_schema.is_ragged:
tensor_names.append(f"{col_name}__values")
tensor_names.append(f"{col_name}__offsets")
else:
tensor_names.append(col_name)
return tensor_names
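For illustration, a minimal sketch of the naming rule (hypothetical column names; this assumes merlin.schema's ColumnSchema still accepts the is_list/is_ragged flags):

import numpy as np
from merlin.schema import ColumnSchema, Schema

# One ragged list column and one scalar column (hypothetical names)
example_schema = Schema(
    [
        ColumnSchema("item_ids", dtype=np.int64, is_list=True, is_ragged=True),
        ColumnSchema("user_id", dtype=np.int64),
    ]
)

# The ragged column expands into a values/offsets pair; the scalar column keeps its name
tensor_names(example_schema)
# -> ["item_ids__values", "item_ids__offsets", "user_id"]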


def match_representations(schema: Schema, dict_array: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert values-only tensors to values/offsets when indicated by the schema
Parameters
----------
schema : Schema
Downstream input schema to match
dict_array : Dict[str, Any]
A dictionary of NumPy or CuPy ndarrays
Returns
-------
Dict[str, Any]
A dictionary of NumPy or CuPy ndarrays with representations adjusted
"""
schema_names = tensor_names(schema)

aligned = {}
for tensor_name in dict_array.keys():
if tensor_name in schema_names:
aligned[tensor_name] = dict_array[tensor_name]
else:
# Ragged columns with fixed shape values
values, offsets = _to_values_offsets(dict_array[tensor_name])
aligned[f"{tensor_name}__values"] = values
aligned[f"{tensor_name}__offsets"] = offsets

return aligned
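Continuing the hypothetical example_schema above, a sketch of the alignment step: a values-only, fixed-shape batch for a ragged column is split into the values/offsets pair the downstream schema expects, while columns that already match pass through unchanged.

batch = {
    "item_ids": np.array([[10, 11, 12], [13, 14, 15]]),  # values-only, shape (2, 3)
    "user_id": np.array([1, 2]),
}

aligned = match_representations(example_schema, batch)
# aligned["item_ids__values"]  -> array([10, 11, 12, 13, 14, 15])
# aligned["item_ids__offsets"] -> array([0, 3, 6], dtype=int32)
# aligned["user_id"]           -> array([1, 2])  (unchanged)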


def _to_values_offsets(array):
"""Convert array to values/offsets representation
Parameters
----------
array : numpy.ndarray or cupy.ndarray
Array to convert
Returns
-------
values, offsets
Tuple of values and offsets
"""
num_rows = array.shape[0]
row_lengths = [array.shape[1]] * num_rows
offsets = [0] + list(itertools.accumulate(row_lengths))
array_lib = cp if cp and isinstance(array, cp.ndarray) else np
offsets = array_lib.array(offsets, dtype="int32")
values = array.reshape(-1, *array.shape[2:])
return values, offsets
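A short worked example of this helper with plain NumPy (illustrative numbers only):

import numpy as np

fixed = np.array([[1, 2, 3], [4, 5, 6]])  # two rows, three values each

values, offsets = _to_values_offsets(fixed)
# values  -> array([1, 2, 3, 4, 5, 6])       rows flattened end to end
# offsets -> array([0, 3, 6], dtype=int32)   row i spans values[offsets[i]:offsets[i + 1]]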


def triton_request_to_tensor_table(request, schema):
"""
Turns a Triton request into a TensorTable by extracting individual tensors
from the request using pb_utils.
@@ -54,19 +137,12 @@ def triton_request_to_tensor_table(request, column_names):
TensorTable
Dictionary-like representation of the input columns
"""
dict_inputs = {}
for name in column_names:
try:
values = _array_from_triton_tensor(request, f"{name}__values")
lengths = _array_from_triton_tensor(request, f"{name}__offsets")
dict_inputs[name] = (values, lengths)
except (AttributeError, ValueError):
dict_inputs[name] = _array_from_triton_tensor(request, name)

return TensorTable(dict_inputs)
return TensorTable(
{name: _array_from_triton_tensor(request, name) for name in tensor_names(schema)}
)


def tensor_table_to_triton_response(tensor_table):
def tensor_table_to_triton_response(tensor_table, schema):
"""
Turns a TensorTable into a Triton response that can be returned
to resolve an incoming request.
@@ -81,20 +157,13 @@ def tensor_table_to_triton_response(tensor_table):
response : TritonInferenceResponse
The output response for predictions
"""
output_tensors = []
for name, column in tensor_table.items():
if column.offsets is not None:
values = _triton_tensor_from_array(f"{name}__values", column.values)
offsets = _triton_tensor_from_array(f"{name}__offsets", column.offsets)
output_tensors.extend([values, offsets])
else:
col_tensor = _triton_tensor_from_array(name, column.values)
output_tensors.append(col_tensor)

return pb_utils.InferenceResponse(output_tensors)
aligned = match_representations(schema, tensor_table.to_dict())
return pb_utils.InferenceResponse(
[_triton_tensor_from_array(name, array) for name, array in aligned.items()]
)


def tensor_table_to_triton_request(model_name, tensor_table, input_col_names, output_col_names):
def tensor_table_to_triton_request(model_name, tensor_table, input_schema, output_schema):
"""
Turns a TensorTable into a Triton request that can, for example, be used to make a
Business Logic Scripting call to a Triton model on the same Triton instance.
@@ -115,26 +184,17 @@ def tensor_table_to_triton_request(model_name, tensor_table, input_col_names, output_col_names):
TritonInferenceRequest
The TensorTable reformatted as a Triton request
"""
input_tensors = []

for name, column in tensor_table.items():
if name in input_col_names:
if column.offsets is not None:
values = _triton_tensor_from_array(f"{name}__values", column.values)
offsets = _triton_tensor_from_array(f"{name}__offsets", column.offsets)
input_tensors.extend([values, offsets])
else:
col_tensor = _triton_tensor_from_array(name, column.values)
input_tensors.append(col_tensor)
aligned = match_representations(input_schema, tensor_table.to_dict())
input_tensors = [_triton_tensor_from_array(name, tensor) for name, tensor in aligned.items()]

return pb_utils.InferenceRequest(
model_name=model_name,
requested_output_names=output_col_names,
requested_output_names=tensor_names(output_schema),
inputs=input_tensors,
)


def triton_response_to_tensor_table(response, transformable_type, output_column_names):
def triton_response_to_tensor_table(response, transformable_type, schema):
"""
Turns a Triton response into a TensorTable by extracting individual tensors
from the request using pb_utils.
@@ -153,20 +213,9 @@ def triton_response_to_tensor_table(response, transformable_type, output_column_names):
Transformable
A TensorTable or DataFrame representing the response columns from a Triton request
"""
outputs_dict = {}

for out_col_name in output_column_names:
try:
values = _array_from_triton_tensor(response, f"{out_col_name}__values")
lengths = _array_from_triton_tensor(response, f"{out_col_name}__offsets")
outputs_dict[out_col_name] = (values, lengths)
except (AttributeError, ValueError):
outputs_dict[out_col_name] = _array_from_triton_tensor(response, out_col_name)

output_val = _array_from_triton_tensor(response, out_col_name)
outputs_dict[out_col_name] = output_val

return transformable_type(outputs_dict)
return transformable_type(
{name: _array_from_triton_tensor(response, name) for name in tensor_names(schema)}
)


def _triton_tensor_from_array(name, array):
4 changes: 2 additions & 2 deletions merlin/systems/triton/models/executor_model.py
@@ -92,9 +92,9 @@ def execute(self, request):
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
inputs = triton_request_to_tensor_table(request, self.ensemble.input_schema.column_names)
inputs = triton_request_to_tensor_table(request, self.ensemble.input_schema)
outputs = self.ensemble.transform(inputs, runtime=TritonExecutorRuntime())
return tensor_table_to_triton_response(outputs)
return tensor_table_to_triton_response(outputs, self.ensemble.output_schema)


def _parse_model_repository(model_repository: str) -> str:
40 changes: 4 additions & 36 deletions merlin/systems/workflow/base.py
@@ -25,18 +25,14 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import functools
import itertools
import json
import logging

import numpy as np

from merlin.core.compat import cupy
from merlin.core.dispatch import concat_columns
from merlin.dag import ColumnSelector, Supports
from merlin.schema import Tags
from merlin.systems.triton.conversions import convert_format
from merlin.table import TensorColumn, TensorTable
from merlin.systems.triton.conversions import convert_format, match_representations
from merlin.table import TensorTable

LOG = logging.getLogger("merlin-systems")

@@ -109,14 +105,8 @@ def run_workflow(self, input_tensors):
if kind != Supports.CPU_DICT_ARRAY:
transformed, kind = convert_format(transformed, kind, Supports.CPU_DICT_ARRAY)

output_table = TensorTable(transformed)

for col in self.workflow.output_schema:
if col.is_ragged and output_table[col.name].offsets is None:
values, offsets = _to_ragged(output_table[col.name].values)
output_table[col.name] = TensorColumn(values, offsets=offsets)

output_dict = output_table.to_dict()
transformed = TensorTable(transformed).to_dict()
output_dict = match_representations(self.workflow.output_schema, transformed)

for key, value in output_dict.items():
output_dict[key] = value.astype(self.output_dtypes[key])
@@ -205,25 +195,3 @@ def _get_param(self, config, *args, default=None):
for key in args:
config_element = config_element.get(key, {})
return config_element or default


def _to_ragged(array):
"""Convert Array to Ragged representation
Parameters
----------
array : numpy.ndarray or cupy.ndarray
Array to convert
Returns
-------
values, offsets
Tuple of values and offsets
"""
num_rows = array.shape[0]
row_lengths = [array.shape[1]] * num_rows
offsets = [0] + list(itertools.accumulate(row_lengths))
array_lib = cupy if cupy and isinstance(array, cupy.ndarray) else np
offsets = array_lib.array(offsets, dtype="int32")
values = array.reshape(-1, *array.shape[2:])
return values, offsets
