Skip to content

Commit

Permalink
Update SoftmaxSampling to support dynamic dtypes (#304)
Browse files Browse the repository at this point in the history
The output types match the input dtypes passed. And the output schema
is computed from the input types passed.
  • Loading branch information
oliverholworthy authored Mar 24, 2023
1 parent cb7e251 commit d1d270e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
10 changes: 8 additions & 2 deletions merlin/systems/dag/ops/softmax_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,14 @@ def compute_output_schema(
"""Describe the operator's outputs"""
return Schema(
[
ColumnSchema("ordered_ids", dtype=np.int32, dims=(None, 1)),
ColumnSchema("ordered_scores", dtype=np.float32, dims=(None, 1)),
ColumnSchema(
"ordered_ids", dtype=input_schema[self._input_col_name].dtype, dims=(None, 1)
),
ColumnSchema(
"ordered_scores",
dtype=input_schema[self._relevance_col_name].dtype,
dims=(None, 1),
),
]
)

Expand Down
41 changes: 40 additions & 1 deletion tests/unit/systems/dag/ops/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,47 @@ def test_compute_dims(column_schema, expected_dims):
assert compute_dims(column_schema) == expected_dims


@pytest.mark.parametrize(
["id_dtype", "score_dtype"],
[
("int32", "float32"),
("int64", "float64"),
],
)
def test_softmax_sampling(id_dtype, score_dtype):
input_schema = Schema(
[
ColumnSchema("movie_ids", dtype=id_dtype, dims=(None, 100)),
ColumnSchema("relevance_score", dtype=score_dtype, dims=(None, 100)),
]
)

movie_ids = np.array(random.sample(range(10000), 100), dtype=id_dtype)
relevance_score = np.random.random(100).astype(score_dtype)

combined_features = {
"movie_ids": np.expand_dims(movie_ids, axis=0),
"relevance_score": np.expand_dims(relevance_score, axis=0),
}

input_table = TensorTable(combined_features)

ordering = ["movie_ids"] >> SoftmaxSampling(
relevance_col="relevance_score", topk=10, temperature=20.0
)

ensemble = Ensemble(ordering, input_schema)
output_table = ensemble.transform(input_table)

assert output_table["ordered_ids"].dtype.to_numpy == input_schema["movie_ids"].dtype.to_numpy
assert (
output_table["ordered_scores"].dtype.to_numpy
== input_schema["relevance_score"].dtype.to_numpy
)


@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
def test_softmax_sampling(tmpdir):
def test_softmax_sampling_with_triton(tmpdir):
request_schema = Schema(
[
ColumnSchema("movie_ids", dtype=np.int32, dims=(None, 100)),
Expand Down

0 comments on commit d1d270e

Please sign in to comment.