diff --git a/tests/unit/torch/tabular/test_transformations.py b/tests/unit/torch/tabular/test_transformations.py index e3b2e00b47..5dc241888b 100644 --- a/tests/unit/torch/tabular/test_transformations.py +++ b/tests/unit/torch/tabular/test_transformations.py @@ -14,6 +14,7 @@ # limitations under the License. # +import numpy as np import pytest import torch from merlin.schema import Tags @@ -21,6 +22,9 @@ import transformers4rec.torch as tr from merlin_standard_lib import schema +np.random.seed(0) +torch.manual_seed(0) + @pytest.mark.parametrize("replacement_prob", [0.1, 0.3, 0.5, 0.7]) def test_stochastic_swap_noise(replacement_prob): diff --git a/transformers4rec/torch/utils/schema_utils.py b/transformers4rec/torch/utils/schema_utils.py index cf3d9d7f4f..2eed3de403 100644 --- a/transformers4rec/torch/utils/schema_utils.py +++ b/transformers4rec/torch/utils/schema_utils.py @@ -17,6 +17,7 @@ import random from typing import Any, Dict, Optional +import numpy as np import torch from merlin.schema.io.proto_utils import has_field @@ -32,9 +33,15 @@ def random_data_from_schema( min_session_length: int = 5, device=None, ragged=False, + seed=0, ) -> TabularData: data: Dict[str, Any] = {} + random.seed(seed) + np.random.seed(seed) + if seed: + torch.manual_seed(seed) + for i in range(num_rows): session_length = None if max_session_length: