Skip to content

Commit

Permalink
added unit test for dataloader seed
Browse files Browse the repository at this point in the history
  • Loading branch information
dgourab-aws committed Dec 26, 2024
1 parent c84cb33 commit 89f54e7
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions axlearn/common/seed_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from absl.testing import absltest
from axlearn.common import test_utils
import tensorflow as tf
import importlib


class SeedTest(test_utils.TestCase):
def test_tf_random_seed_from_env(self):
os.environ["DATA_SEED"] = "42"
importlib.import_module('axlearn.common.input_lm')
sequence_1 = [tf.random.uniform((1,)).numpy() for _ in range(5)]

# Re-import 'input_lm' to reset seed
importlib.reload(importlib.import_module('axlearn.common.input_lm'))
sequence_2 = [tf.random.uniform((1,)).numpy() for _ in range(5)]

# Assert that the two sequences are same
self.assertEqual(sequence_1, sequence_2)

if __name__ == "__main__":
absltest.main()

0 comments on commit 89f54e7

Please sign in to comment.