Commit
Update QUEST README.
PiperOrigin-RevId: 686571939
Language Team authored and kentonl committed Oct 17, 2024
1 parent 815bfe6 commit 865fae6
Showing 41 changed files with 102 additions and 4,690 deletions.
5 changes: 1 addition & 4 deletions language/common/layers/affine_transform_test.py
@@ -30,7 +30,7 @@ def test_layer_api_compatibility(self):
     with tf.keras.utils.CustomObjectScope(
         {cls.__name__: cls}
     ):
-      output = tf._keras_internal.testing_infra.test_utils.layer_test(
+      _ = tf._keras_internal.testing_infra.test_utils.layer_test(
           cls,
           kwargs={
               'output_size': 1,
@@ -40,9 +40,6 @@ def test_layer_api_compatibility(self):
           input_data=input_array,
       )
 
-    expected_values = tf.constant([[0.01368301], [0.01368301], [0.0314441]])
-    self.assertAllClose(expected_values, output)
-
 
 if __name__ == '__main__':
   tf.test.main()
8 changes: 4 additions & 4 deletions language/gscan/xattn_model/model/decode.py
@@ -98,7 +98,7 @@ def gather_fn(x):
       return x
     else:
       return x[batch_indices, beam_indices]
-  return jax.tree_map(gather_fn, nested)
+  return jax.tree.map(gather_fn, nested)
 
 
 def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
@@ -154,7 +154,7 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
       (batch_size, beam_size, max_decode_len), jnp.int32)
   finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
   # add beam dimension to attention cache pytree elements
-  beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
+  beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
   return BeamState(cur_index=cur_index0,
                    live_logprobs=live_logprobs0,
                    finished_scores=finished_scores0,
@@ -239,7 +239,7 @@ def beam_search_loop_body_fn(state):
         (batch_size, beam_size, 1)))
     # Flatten beam dimension into batch to be compatible with model.
     # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
-    flat_cache = jax.tree_map(flatten_beam_dim, state.cache)
+    flat_cache = jax.tree.map(flatten_beam_dim, state.cache)
 
     # Call fast-decoder model on current tokens to get next-position logits.
     # --> [batch * beam, vocab]
@@ -250,7 +250,7 @@ def beam_search_loop_body_fn(state):
     logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
     # Unflatten beam dimension in attention cache arrays
     # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
-    new_cache = jax.tree_map(
+    new_cache = jax.tree.map(
         lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)
 
     # Gather log probabilities from logits
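The shape contracts of the beam helpers used above are spelled out in the inline comments. A minimal sketch of what such helpers look like (the names come from the diff; the bodies are assumptions, not the repository's exact code):

import jax.numpy as jnp

def add_beam_dim(x, beam_size):
  # [batch, ...] --> [batch, beam, ...]: insert a beam axis and tile into it.
  return jnp.repeat(jnp.expand_dims(x, axis=1), beam_size, axis=1)

def flatten_beam_dim(x):
  # [batch, beam, ...] --> [batch * beam, ...]: fold beams into the batch axis.
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])

def unflatten_beam_dim(x, batch_size, beam_size):
  # [batch * beam, ...] --> [batch, beam, ...]: undo flatten_beam_dim.
  return x.reshape((batch_size, beam_size) + x.shape[1:])

Each helper is applied leaf-by-leaf to the attention-cache pytree via jax.tree.map, which is why gather_fn above returns scalar leaves unchanged rather than indexing them.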
6 changes: 3 additions & 3 deletions language/gscan/xattn_model/predict.py
@@ -42,7 +42,7 @@ def remove_pad(x):
   """Remove padding examples."""
   if 'mask' in x:
     ind = jnp.where(jnp.array(x.pop('mask'), dtype=jnp.int32) > 0)
-    x = jax.tree_map(lambda v: v[ind], x)  # pylint: disable=cell-var-from-loop
+    x = jax.tree.map(lambda v: v[ind], x)  # pylint: disable=cell-var-from-loop
   return x
 

@@ -53,7 +53,7 @@ def single_tohost(x):
     n_device, n_batch, *remaining_dims = x.shape
     return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))
 
-  return jax.tree_map(single_tohost, x)
+  return jax.tree.map(single_tohost, x)
 
 
 def remove_special_tokens(tokens, eos_idx):
@@ -152,7 +152,7 @@ def evaluate_sequence_accuracy(p_pred_step,
     annotations = json.load(f)
 
   for step, batch in enumerate(ds):  # pytype: disable=wrong-arg-types
-    batch = jax.tree_map(np.asarray, batch)
+    batch = jax.tree.map(np.asarray, batch)
     cache = p_init_cache(batch)
     batch['predictions'] = p_pred_step(batch, state, cache, eos_idx)
     batch = remove_pad(tohost(batch))
4 changes: 2 additions & 2 deletions language/gscan/xattn_model/train.py
@@ -98,7 +98,7 @@ def evaluate(p_eval_step, state, eval_ds, num_eval_steps = -1):
   logging.info('Starting evaluating.')
   eval_metrics = []
   for step, batch in enumerate(eval_ds):
-    batch = jax.tree_map(np.asarray, batch)
+    batch = jax.tree.map(np.asarray, batch)
     metrics = p_eval_step(batch=batch, state=state)
     eval_metrics.append(metrics)
     if num_eval_steps > 0 and step + 1 == num_eval_steps:
@@ -204,7 +204,7 @@ def train_and_evaluate(config, workdir):
   for step in range(initial_step, num_train_steps + 1):
     is_last_step = step == num_train_steps
     with jax.profiler.StepTraceAnnotation('train', step_num=step):
-      batch = jax.tree_map(np.asarray, next(train_iter))
+      batch = jax.tree.map(np.asarray, next(train_iter))
       state, metrics = p_train_step(batch=batch, rng=train_rngs, state=state)
       train_metrics.append(metrics)
 
4 changes: 2 additions & 2 deletions language/gscan/xattn_model/train_utils.py
@@ -199,13 +199,13 @@ def compute_metrics(logits, targets, weights):
 
 def metrics_summary(metrics, prefix):
   """Gather metrics summary."""
-  metrics_sums = jax.tree_map(jnp.sum, metrics)
+  metrics_sums = jax.tree.map(jnp.sum, metrics)
   weight_sum = metrics_sums.pop('weight_sum')
   example_sum = metrics_sums.pop('example_sum')
   exact_match = metrics_sums.pop('exact_match')
   summary = {
       f'{prefix}_{k}': v
-      for k, v in jax.tree_map(lambda x: x / weight_sum, metrics_sums).items()
+      for k, v in jax.tree.map(lambda x: x / weight_sum, metrics_sums).items()
   }
   summary[f'{prefix}_exact_match'] = exact_match / example_sum
   return summary
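
Most hunks in this commit make the same mechanical change: the deprecated top-level alias jax.tree_map becomes jax.tree.map, which recent JAX releases expose under the jax.tree namespace. Both apply a function to every leaf of a pytree and rebuild the structure. A self-contained sketch of the call, using a hypothetical pytree:

import jax
import jax.numpy as jnp

# A pytree is any nested structure of containers (dicts, lists, tuples)
# whose leaves are arrays or scalars.
batch = {'inputs': jnp.arange(6).reshape(2, 3), 'mask': jnp.ones((2, 3))}

# jax.tree.map applies the function leaf-by-leaf and returns a pytree with
# the same structure; it is the drop-in replacement for jax.tree_map.
doubled = jax.tree.map(lambda x: x * 2, batch)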
28 changes: 0 additions & 28 deletions language/instructability_eval/README.md

This file was deleted.

