diff --git a/merlin/systems/dag/ops/softmax_sampling.py b/merlin/systems/dag/ops/softmax_sampling.py
index d10547179..8c8e3307d 100644
--- a/merlin/systems/dag/ops/softmax_sampling.py
+++ b/merlin/systems/dag/ops/softmax_sampling.py
@@ -101,8 +101,14 @@ def compute_output_schema(
         """Describe the operator's outputs"""
         return Schema(
             [
-                ColumnSchema("ordered_ids", dtype=np.int32, dims=(None, 1)),
-                ColumnSchema("ordered_scores", dtype=np.float32, dims=(None, 1)),
+                ColumnSchema(
+                    "ordered_ids", dtype=input_schema[self._input_col_name].dtype, dims=(None, 1)
+                ),
+                ColumnSchema(
+                    "ordered_scores",
+                    dtype=input_schema[self._relevance_col_name].dtype,
+                    dims=(None, 1),
+                ),
             ]
         )
 
diff --git a/tests/unit/systems/dag/ops/test_ops.py b/tests/unit/systems/dag/ops/test_ops.py
index 76fa0a9b4..6f05af695 100644
--- a/tests/unit/systems/dag/ops/test_ops.py
+++ b/tests/unit/systems/dag/ops/test_ops.py
@@ -45,8 +45,47 @@ def test_compute_dims(column_schema, expected_dims):
     assert compute_dims(column_schema) == expected_dims
 
 
+@pytest.mark.parametrize(
+    ["id_dtype", "score_dtype"],
+    [
+        ("int32", "float32"),
+        ("int64", "float64"),
+    ],
+)
+def test_softmax_sampling(id_dtype, score_dtype):
+    input_schema = Schema(
+        [
+            ColumnSchema("movie_ids", dtype=id_dtype, dims=(None, 100)),
+            ColumnSchema("relevance_score", dtype=score_dtype, dims=(None, 100)),
+        ]
+    )
+
+    movie_ids = np.array(random.sample(range(10000), 100), dtype=id_dtype)
+    relevance_score = np.random.random(100).astype(score_dtype)
+
+    combined_features = {
+        "movie_ids": np.expand_dims(movie_ids, axis=0),
+        "relevance_score": np.expand_dims(relevance_score, axis=0),
+    }
+
+    input_table = TensorTable(combined_features)
+
+    ordering = ["movie_ids"] >> SoftmaxSampling(
+        relevance_col="relevance_score", topk=10, temperature=20.0
+    )
+
+    ensemble = Ensemble(ordering, input_schema)
+    output_table = ensemble.transform(input_table)
+
+    assert output_table["ordered_ids"].dtype.to_numpy == input_schema["movie_ids"].dtype.to_numpy
+    assert (
+        output_table["ordered_scores"].dtype.to_numpy
+        == input_schema["relevance_score"].dtype.to_numpy
+    )
+
+
 @pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
-def test_softmax_sampling(tmpdir):
+def test_softmax_sampling_with_triton(tmpdir):
     request_schema = Schema(
         [
             ColumnSchema("movie_ids", dtype=np.int32, dims=(None, 100)),