Skip to content

Commit

Permalink
Return correct dtype for offsets in TransformWorkflow Operator (#361)
Browse files Browse the repository at this point in the history
* Add assertion for expected dtypes in workflow tests

* Coerce offsets to int32 in match_representations

* Re-format test_ensemble.py
  • Loading branch information
oliverholworthy authored Jun 6, 2023
1 parent d763e7b commit ed7009c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions merlin/systems/triton/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def match_representations(schema: Schema, dict_array: Dict[str, Any]) -> Dict[st

if dtype != md.unknown:
aligned[vals_name] = aligned[vals_name].astype(dtype.to_numpy)

aligned[offs_name] = aligned[offs_name].astype("int32")
else:
try:
# Look for values and offsets that already exist,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_workflow_with_ragged_output(tmpdir):
)
for key, value in expected_response.items():
np.testing.assert_array_equal(response[key], value)
assert response[key].dtype == value.dtype


@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
Expand Down Expand Up @@ -246,6 +247,7 @@ def test_workflow_with_padded_output(tmpdir):
)
for key, value in expected_response.items():
np.testing.assert_array_equal(response[key], value)
assert response[key].dtype == value.dtype


@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
Expand Down Expand Up @@ -306,6 +308,7 @@ def test_workflow_with_ragged_input_and_output(tmpdir):
)
for key, value in expected_response.items():
np.testing.assert_array_equal(response[key], value)
assert response[key].dtype == value.dtype


@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
Expand Down Expand Up @@ -374,3 +377,4 @@ def test_workflow_dtypes(tmpdir):
)
for key, value in expected_response.items():
np.testing.assert_array_equal(response[key], value)
assert response[key].dtype == value.dtype

0 comments on commit ed7009c

Please sign in to comment.