Commit

Changed to global seed
dgourab committed Dec 16, 2024
1 parent 0252e81 commit 6f5ae2d
Showing 2 changed files with 11 additions and 27 deletions.
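
In short: the parent commit (0252e81) read the DATA_SEED environment variable at every call site and passed it explicitly to each ds.shuffle(...) and to tfds.ReadConfig. This commit deletes those per-call reads and instead reads DATA_SEED once, when axlearn/common/input_lm.py is imported, installing it as TensorFlow's global seed with tf.random.set_seed; un-seeded tf.data ops then derive their seeds from that global value.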
axlearn/common/input_lm.py: 5 additions, 6 deletions
@@ -9,8 +9,6 @@
 from enum import Enum
 from typing import Any, Literal, Optional
 
-import os
-
 import seqio
 import tensorflow as tf
 from absl import logging
@@ -23,6 +21,9 @@
 # Value of "target_labels" which will be ignored in seq2seq processing.
 SEQ2SEQ_IGNORE_TARGET_LABEL = -1
 
+seed = os.environ.get("DATA_SEED")
+seed = int(seed) if seed is not None else None
+tf.random.set_seed(seed)
 
 class InputDataType(Enum):
     """Represents input data types for decoder-only language model training.
@@ -228,9 +229,7 @@ def process(ds: tf.data.Dataset) -> tf.data.Dataset:
         ds = ds.unbatch()
         # Shuffle so that read order is not dominated by document order.
         if shuffle_buffer_size > 0:
-            seed = os.environ.get("DATA_SEED")
-            seed = int(seed) if seed is not None else None
-            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True, seed = seed)
+            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
         return ds
 
     return process
@@ -1179,4 +1178,4 @@ def pad_prefix():
         # 2. Control max decode length by prefix.shape[-1].
         config_for_function(pad_prefix),
         is_training=False,
-    )
+    )
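
For readers unfamiliar with TensorFlow's two-level seeding: tf.random.set_seed installs a global seed, and stateful ops created without an explicit seed argument, such as the Dataset.shuffle call above, derive their op seed from the global seed plus an internal op counter. (Note that this diff removes import os from input_lm.py even though the new module-level block still calls os.environ.get; presumably os reaches the module some other way not visible in these hunks.) Below is a minimal sketch of the pattern; only the DATA_SEED convention comes from the diff, the rest is illustrative:

    import os

    import tensorflow as tf

    # Read DATA_SEED once at module load, mirroring the new input_lm.py block.
    seed = os.environ.get("DATA_SEED")
    seed = int(seed) if seed is not None else None
    tf.random.set_seed(seed)


    def make_pipeline() -> tf.data.Dataset:
        # No per-op seed: with a global seed installed, shuffle's seed is
        # derived from it, so the order is reproducible for a given DATA_SEED.
        return tf.data.Dataset.range(10).shuffle(10, reshuffle_each_iteration=True)


    if __name__ == "__main__":
        # tf.random.set_seed also resets the counter from which op seeds are
        # drawn, so re-seeding before each build makes the orders comparable.
        tf.random.set_seed(seed)
        first = [int(x) for x in make_pipeline()]
        tf.random.set_seed(seed)
        second = [int(x) for x in make_pipeline()]
        assert seed is None or first == second
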
axlearn/common/input_tf_data.py: 6 additions, 21 deletions
@@ -11,8 +11,6 @@
 from collections.abc import Mapping, Sequence
 from typing import Any, Callable, Optional, Union
 
-import os
-
 import jax
 import seqio
 import tensorflow as tf
@@ -76,7 +74,6 @@ def tfds_read_config(
     num_shards: Optional[int] = None,
     shard_index: Optional[int] = None,
     read_parallelism: int = 1,
     decode_parallelism: int = 32,
-    seed: Optional[int] = None,
 ) -> tfds.ReadConfig:
     """Constructs a ReadConfig for tfds dataset.
@@ -98,8 +95,6 @@
     Returns:
         A tfds.ReadConfig.
     """
-    seed = os.environ.get("DATA_SEED")
-    seed = int(seed) if seed is not None else None
     num_shards = jax.process_count() if num_shards is None else num_shards
     shard_index = jax.process_index() if shard_index is None else shard_index
     num_parallel_calls_for_read = read_parallelism if is_training else 1
@@ -111,7 +106,6 @@
         input_context=tf.distribute.InputContext(
             num_input_pipelines=num_shards, input_pipeline_id=shard_index
         ),
-        shuffle_seed=seed,
     )
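
A note on the two hunks above: with the seed parameter gone and shuffle_seed no longer set, tfds_read_config stops pinning TFDS file-shard shuffling to DATA_SEED directly; since TFDS shuffles via tf.data internally, the apparent intent is that these ops fall back to the global seed as well, though the diff does not show that fallback explicitly. A rough sketch of what the function now builds (the tfds.ReadConfig field names are real, but their pairing with the diff's parallelism values is my assumption):

    import jax
    import tensorflow as tf
    import tensorflow_datasets as tfds

    def read_config_sketch(
        is_training: bool,
        read_parallelism: int = 1,
        decode_parallelism: int = 32,
    ) -> tfds.ReadConfig:
        # Shard the dataset across JAX processes, as in the surrounding code.
        num_shards = jax.process_count()
        shard_index = jax.process_index()
        num_parallel_calls_for_read = read_parallelism if is_training else 1
        return tfds.ReadConfig(
            num_parallel_calls_for_interleave_files=num_parallel_calls_for_read,
            num_parallel_calls_for_decode=decode_parallelism,
            input_context=tf.distribute.InputContext(
                num_input_pipelines=num_shards, input_pipeline_id=shard_index
            ),
            # shuffle_seed is intentionally left unset after this commit.
        )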


@@ -275,14 +269,11 @@ def tfds_dataset(
 
     if data_dir is None:
         data_dir = get_data_dir()
 
-    seed = os.environ.get("DATA_SEED")
-    seed = int(seed) if seed is not None else None
-
     if read_config is None:
-        read_config = config_for_function(tfds_read_config).set(is_training=is_training, seed = seed)
+        read_config = config_for_function(tfds_read_config).set(is_training=is_training)
     else:
-        read_config = read_config.set(is_training=is_training, seed = seed)
+        read_config = read_config.set(is_training=is_training)
 
     def fn() -> tf.data.Dataset:
         local_read_config = read_config.clone()
@@ -318,9 +309,7 @@ def fn() -> tf.data.Dataset:
         if shuffle_buffer_size > 0:
             # Subsequent processing may merge/split examples (e.g. for T5), so shuffle examples
             # during training before any processing.
-            seed = os.environ.get("DATA_SEED")
-            seed = int(seed) if seed is not None else None
-            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True, seed = seed)
+            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
         return ds
 
     return fn
@@ -380,9 +369,7 @@ def fn() -> tf.data.Dataset:
         if shuffle_buffer_size > 0:
             # Subsequent processing may merge/split examples (e.g. for T5), so shuffle examples
             # during training before any processing.
-            seed = os.environ.get("DATA_SEED")
-            seed = int(seed) if seed is not None else None
-            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True, seed = seed)
+            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
         return ds
 
     return fn
@@ -1090,9 +1077,7 @@ def shuffle(shuffle_buffer_size: int) -> DatasetToDatasetFn:
 
     def fn(ds: tf.data.Dataset) -> tf.data.Dataset:
         if shuffle_buffer_size > 0:
-            seed = os.environ.get("DATA_SEED")
-            seed = int(seed) if seed is not None else None
-            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True, seed = seed)
+            ds = ds.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)
 
         return ds
 
@@ -1486,4 +1471,4 @@ def trim_and_pad_tensor(
     t = tf.pad(t, [(0, 0)] * (len(t.shape) - 1) + [(0, pad_amt)], constant_values=pad_id)
     t = tf.ensure_shape(t, t.shape[:-1] + [max_len])
 
-    return t
+    return t
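
Tying the two files together: after this commit the single seeding knob is the DATA_SEED environment variable, consumed once when input_lm.py is imported. A hedged way to check the resulting run-to-run determinism (the inline script is a stand-in for importing the real pipeline; only DATA_SEED comes from the diff):

    import os
    import subprocess

    env = dict(os.environ, DATA_SEED="42")

    # Each run is a fresh process, so the global seed is installed once at
    # startup and the un-seeded shuffle derives its seed from it.
    script = (
        "import os, tensorflow as tf;"
        "tf.random.set_seed(int(os.environ['DATA_SEED']));"
        "ds = tf.data.Dataset.range(8).shuffle(8);"
        "print([int(x) for x in ds])"
    )
    runs = [
        subprocess.run(
            ["python", "-c", script], env=env, capture_output=True, text=True, check=True
        ).stdout
        for _ in range(2)
    ]
    assert runs[0] == runs[1]  # same shuffle order for the same DATA_SEED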
