Skip to content

Commit

Permalink
fix: changing all PRNGKey to key because its going to be deprecated soon
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Dec 22, 2024
1 parent 53eff1d commit 23881fd
Show file tree
Hide file tree
Showing 45 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __eq__(self, other):
nn.Module]
ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]]

RandomState = Any # Union[jax.random.PRNGKey, int, bytes, ...]
RandomState = Any # Union[jax.random.key, int, bytes, ...]

OptimizerState = Union[Dict[str, Any], Tuple[Any, Any]]
Hyperparameters = Any
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/fastmri/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def process_example(example_index, example):
process_rng, example_index)
else:
# NOTE(dsuo): we use fixed randomness for eval.
process_rng = tf.cast(jax.random.PRNGKey(_EVAL_SEED), tf.int64)
process_rng = tf.cast(jax.random.key(_EVAL_SEED), tf.int64)
return _process_example(*example, process_rng)

ds = ds.enumerate().map(process_example, num_parallel_calls=16)
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/mnist/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _normalize(image: spec.Tensor, mean: float, stddev: float) -> spec.Tensor:


def _build_mnist_dataset(
data_rng: jax.random.PRNGKey,
data_rng: jax.random.key,
num_train_examples: int,
num_validation_examples: int,
train_mean: float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def loss_fn(
return loss_dict

def _build_input_queue(self,
data_rng: jax.random.PRNGKey,
data_rng: jax.random.key,
split: str,
data_dir: str,
global_batch_size: int):
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/ogbg/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def eval_period_time_sec(self) -> int:
return 4 * 60

def _build_input_queue(self,
data_rng: jax.random.PRNGKey,
data_rng: jax.random.key,
split: str,
data_dir: str,
global_batch_size: int):
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def print_jax_model_summary(model, fake_inputs):
"""Prints a summary of the jax module."""
tabulate_fn = nn.tabulate(
model,
jax.random.PRNGKey(0),
jax.random.key(0),
console_kwargs={
'force_terminal': False, 'force_jupyter': False, 'width': 240
},
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def initialize_cache(self,
config = models.TransformerConfig(deterministic=True, decode=True)
target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
initial_variables = models.Transformer(config).init(
jax.random.PRNGKey(0),
jax.random.key(0),
jnp.ones(inputs.shape, jnp.float32),
jnp.ones(target_shape, jnp.float32))
return initial_variables['cache']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def model_fn(
return logits_batch, None

def _build_input_queue(self,
data_rng: jax.random.PRNGKey,
data_rng: jax.random.key,
split: str,
data_dir: str,
global_batch_size: int,
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/wmt/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def glu(self) -> bool:
return False

def _build_input_queue(self,
data_rng: jax.random.PRNGKey,
data_rng: jax.random.key,
split: str,
data_dir: str,
global_batch_size: int,
Expand Down
2 changes: 1 addition & 1 deletion docker/scripts/check_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

print('JAX identified %d GPU devices' % jax.local_device_count())
print('Generating RNG seed for CUDA sanity check ... ')
rng = jax.random.PRNGKey(0)
rng = jax.random.key(0)
data_rng, shuffle_rng = jax.random.split(rng, 2)

if jax.local_device_count() == 8 and data_rng is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/criteo1tb/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/criteo1tb_embed_init/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/criteo1tb_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
# mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/criteo1tb_resnet/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def torch2jax(jax_workload,
key_transform=None,
sd_transform=None,
init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0)):
jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0),
jax_params, model_state = jax_workload.init_model_fn(jax.random.key(0),
**init_kwargs)
pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs)
jax_params = jax_utils.unreplicate(jax_params).unfreeze()
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/fastmri/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def sort_key(k):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/fastmri_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def sort_key(k):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/fastmri_model_size/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def sort_key(k):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/fastmri_tanh/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def sort_key(k):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/imagenet_resnet/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/imagenet_resnet/gelu_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/imagenet_resnet/silu_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/imagenet_vit/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def key_transform(k):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/imagenet_vit_glu/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/imagenet_vit_map/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/imagenet_vit_postln/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/librispeech_conformer/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/librispeech_conformer_gelu/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/librispeech_deepspeech/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/ogbg/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/ogbg_gelu/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/ogbg_model_size/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/ogbg_silu/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/wmt/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/wmt_attention_temp/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/wmt_glu_tanh/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/modeldiffs/wmt_post_ln/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def sd_transform(sd):
jax_model_kwargs = dict(
augmented_and_preprocessed_input_batch=jax_batch,
mode=spec.ForwardPassMode.EVAL,
rng=jax.random.PRNGKey(0),
rng=jax.random.key(0),
update_batch_norm=False)

out_diff(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_param_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_workload(workload):
pytorch_workload = PyTorchWmtWorkload()
else:
raise ValueError(f'Workload {workload} is not available.')
_ = jax_workload.init_model_fn(jax.random.PRNGKey(0))
_ = jax_workload.init_model_fn(jax.random.key(0))
_ = pytorch_workload.init_model_fn([0])
return jax_workload, pytorch_workload

Expand Down
2 changes: 1 addition & 1 deletion tests/test_param_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def get_workload(workload_name):
pytorch_workload = PyTorchWmtWorkload()
else:
raise ValueError(f'Workload {workload_name} is not available.')
_ = jax_workload.init_model_fn(jax.random.PRNGKey(0))
_ = jax_workload.init_model_fn(jax.random.key(0))
_ = pytorch_workload.init_model_fn([0])
return jax_workload, pytorch_workload

Expand Down
Loading

0 comments on commit 23881fd

Please sign in to comment.