diff --git a/gemma/colabs/fine_tuning_tutorial.ipynb b/gemma/colabs/fine_tuning_tutorial.ipynb index 7543c74..53545a5 100644 --- a/gemma/colabs/fine_tuning_tutorial.ipynb +++ b/gemma/colabs/fine_tuning_tutorial.ipynb @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "id": "XpSw-_4EEcoY" }, @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "id": "8q5seOhcUBhx" }, @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "id": "9PEefz8wEcoY" }, @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "id": "yWaP_LPoEcoY" }, @@ -166,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "pg8SfQH0EcoY" @@ -207,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "TpyG5YW1EcoY" @@ -235,7 +235,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "L9cjK0uxEcoY" @@ -313,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "xEA-97ioEcoY" @@ -358,7 +358,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "XwFFs2mDEcoY" @@ -506,7 +506,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "bYeduOaNEcoZ" @@ -540,7 +540,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "VDlfziQVEcoZ" @@ -574,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "jWr6Sea_EcoZ" @@ -590,7 +590,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "S6937NTjEcoZ" @@ -617,7 +617,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "LbJa4S5WEcoZ" @@ -648,7 +648,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "iEcV0XEEEcoZ" @@ -688,11 +688,11 @@ " )\n", "\n", " # Exclude the last step as it does not appear in the targets.\n", - " logits = logits[0, :-1]\n", + " logits = logits[:, :-1, ...]\n", "\n", - " # Similarly, the first token cannot be predicteds.\n", - " target_tokens = input_tokens[0, 1:]\n", - " target_mask = input_mask[0, 1:]\n", + " # Similarly, the first token cannot be predicted.\n", + " target_tokens = input_tokens[..., 1:]\n", + " target_mask = input_mask[..., 1:]\n", "\n", " # Convert the target labels into one-hot encoded vectors.\n", " one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n", @@ -718,7 +718,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "cbWfdHf0EcoZ" @@ -726,12 +726,18 @@ "outputs": [], "source": [ "def get_attention_mask_and_positions(example: jax.Array,\n", + " target_mask: jax.Array,\n", " pad_id : int,\n", " )-\u003e tuple[jax.Array, jax.Array]:\n", " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", - " pad_mask = example != pad_id\n", - " current_token_position = 
transformer_lib.build_positions_from_mask(pad_mask)\n", - " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", + " input_mask = example != pad_id\n", + " current_token_position = transformer_lib.build_positions_from_mask(input_mask)\n", + " prefix_mask = (\n", + " input_mask.astype(jnp.int32) - target_mask.astype(jnp.int32)\n", + " ).astype(jnp.bool)\n", + " attention_mask = transformer_lib.make_causal_with_prefix_attn_mask(\n", + " input_mask, prefix_mask\n", + " )\n", " return current_token_position, attention_mask" ] }, @@ -746,7 +752,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "cPSfp7ZUEcoZ" @@ -774,7 +780,9 @@ " \"\"\"\n", "\n", " # Build the position and attention mask vectors.\n", - " positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n", + " positions, attention_mask = get_attention_mask_and_positions(\n", + " example.input_tokens, example.target_mask, pad_id\n", + " )\n", "\n", " # Forward and backward passes\n", " train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,\n", @@ -801,7 +809,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "yU4oR92YEcoa" @@ -813,7 +821,9 @@ " pad_id: int,\n", " example: TrainingInput,\n", " ):\n", - " positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n", + " positions, attention_mask = get_attention_mask_and_positions(\n", + " example.input_tokens, example.target_mask, pad_id\n", + " )\n", " val_loss = forward_and_loss_fn(params,\n", " model=model,\n", " input_tokens=example.input_tokens,\n", @@ -834,7 +844,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "xT4bAqNLEcoa" @@ -932,7 +942,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "7SL2VAmVEcoa" @@ -966,7 +976,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "dQ1oCF10Ecod" @@ -991,7 +1001,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "id": "S5F3fk22Ecod" @@ -1008,7 +1018,9 @@ "metadata": { "accelerator": "GPU", "colab": { - "private_outputs": true + "last_runtime": {}, + "private_outputs": true, + "provenance": [] }, "kernelspec": { "display_name": "Python 3", diff --git a/gemma/transformer.py b/gemma/transformer.py index 4c5b1d4..52e34ad 100644 --- a/gemma/transformer.py +++ b/gemma/transformer.py @@ -337,25 +337,64 @@ def __call__( def make_causal_attn_mask( - input_mask: jax.Array, + input_mask: jax.Array, # Shape [B, L] ) -> jax.Array: - """Attention mask in batch mode. + """Makes a causal attention mask. + + I.e., as in middle diagram of Figure 3 in https://arxiv.org/pdf/1910.10683. Args: input_mask: Input mask for the input. True for non-padded tokens only, else False. Returns: - Attention mask. + Attention mask of shape [B, L, L] (where B=batch dim and L=sequence dim). """ + if len(input_mask.shape) != 2: + raise ValueError( + f'Input mask must be 2D (shape [B, L]), but got {input_mask.shape}.' 
+ ) seq_len = input_mask.shape[-1] + causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool)) attn_mask = input_mask[..., None, :] - causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) - # Prefixes can be attended by all tokens attn_mask *= causal_mask[None, ...] return attn_mask +def make_causal_with_prefix_attn_mask( + input_mask: jax.Array, # Shape [B, L] + prefix_mask: jax.Array, # Shape [B, L] +) -> jax.Array: + """Makes a causal with prefix attention mask. + + I.e., as in the right diagram of Figure 3 in https://arxiv.org/pdf/1910.10683. + + Args: + input_mask: Input mask for the input. True for non-padded tokens only, else + False. + prefix_mask: Input mask for the prefix. True for prefix tokens only, else + False. + + Returns: + Attention mask of shape [B, L, L] (where B=batch dim and L=sequence dim). + """ + if len(input_mask.shape) != 2: + raise ValueError( + f'Input mask must be 2D (shape [B, L]), but got {input_mask.shape}.' + ) + if len(prefix_mask.shape) != 2: + raise ValueError( + f'Prefix mask must be 2D (shape [B, L]), but got {prefix_mask.shape}.' + ) + seq_len = input_mask.shape[-1] + causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool)) + prefix_mask = jnp.tile(jnp.expand_dims(prefix_mask, axis=1), [1, seq_len, 1]) + causal_or_prefix_mask = jnp.logical_or(causal_mask, prefix_mask) + attn_mask = input_mask[..., None, :] + attn_mask *= causal_or_prefix_mask + return attn_mask + + def build_positions_from_mask(input_mask: jax.Array) -> jax.Array: """Computes the `positions` from the `input_mask`. diff --git a/gemma/transformer_test.py b/gemma/transformer_test.py index 68b2812..2284ff2 100644 --- a/gemma/transformer_test.py +++ b/gemma/transformer_test.py @@ -390,5 +390,133 @@ def test_query_pre_attn_scalar( self.assertEqual(config.query_pre_attn_scalar(), expected_scalar) +class TransformerUtilsTest(parameterized.TestCase): + + def test_make_causal_attn_mask(self): + input_mask = jnp.array( + [[True, True, True, False, False], [True, True, True, True, False]] + ) + causal_attn_mask = transformer_lib.make_causal_attn_mask(input_mask) + + expected_mask_shape = tuple(list(input_mask.shape) + [input_mask.shape[-1]]) + self.assertEqual(causal_attn_mask.shape, expected_mask_shape) + self.assertEqual(causal_attn_mask.dtype, jnp.bool) + + # This reduces the attention mask, to a mask of which tokens are ever (once + # or more) attended to. It should be the same as the input mask, if + # attention mask is correct. + token_ever_attended_mask = jnp.sum( + jnp.astype(causal_attn_mask, jnp.int32), axis=1, dtype=jnp.bool + ) + np.testing.assert_array_equal(input_mask, token_ever_attended_mask) + + # Iterate over sequences in batch. + for i in range(causal_attn_mask.shape[0]): + + last_number_of_tokens_attended = 0 + # Iterate over tokens in sequence. + for j in range(causal_attn_mask.shape[1]): + if not input_mask[i, j]: + break + number_of_tokens_attended = jnp.sum( + jnp.astype(causal_attn_mask[i, j, :], jnp.int32) + ) + # Each token in the sequence pays attention to one more token than the + # previous token in the sequence. 
+ self.assertEqual( + number_of_tokens_attended, last_number_of_tokens_attended + 1 + ) + last_number_of_tokens_attended = number_of_tokens_attended + + def test_make_causal_attn_mask_fails_with_bad_input_mask_shape(self): + bad_input_mask = jnp.array([[[True]]]) + with self.assertRaises(ValueError): + transformer_lib.make_causal_attn_mask(bad_input_mask) + + def test_make_causal_with_prefix_attn_mask(self): + input_mask = jnp.array( + [[True, True, True, False, False], [True, True, True, True, False]] + ) + prefix_mask = jnp.array( + [[True, True, False, False, False], [True, True, False, False, False]] + ) + causal_with_prefix_attn_mask = ( + transformer_lib.make_causal_with_prefix_attn_mask( + input_mask, prefix_mask + ) + ) + + expected_mask_shape = tuple(list(input_mask.shape) + [input_mask.shape[-1]]) + self.assertEqual(causal_with_prefix_attn_mask.shape, expected_mask_shape) + self.assertEqual(causal_with_prefix_attn_mask.dtype, jnp.bool) + + # This reduces the attention mask, to a mask of which tokens are ever (once + # or more) attended to. It should be the same as the input mask, if + # attention mask is correct. + token_ever_attended_mask = jnp.sum( + jnp.astype(causal_with_prefix_attn_mask, jnp.int32), + axis=1, + dtype=jnp.bool, + ) + np.testing.assert_array_equal(input_mask, token_ever_attended_mask) + + # This reduces the attention mask, to a mask of which tokens are *always* + # attended to. It should be the same as the prefix mask, if attention mask + # is correct. + token_always_attended_mask = jnp.prod( + jnp.astype(causal_with_prefix_attn_mask, jnp.int32), + axis=1, + dtype=jnp.bool, + ) + np.testing.assert_array_equal(prefix_mask, token_always_attended_mask) + + # Iterate over sequences in batch. + for i in range(causal_with_prefix_attn_mask.shape[0]): + + last_number_of_tokens_attended = 0 + # Iterate over tokens in sequence. + for j in range(causal_with_prefix_attn_mask.shape[1]): + if not input_mask[i, j]: + break + number_of_tokens_attended = jnp.sum( + jnp.astype(causal_with_prefix_attn_mask[i, j, :], jnp.int32) + ) + + if prefix_mask[i, j]: + # Each token in the prefix part of the sequence pays attention to all + # the tokens in the prefix part of the sequence. + self.assertEqual( + number_of_tokens_attended, jnp.sum(prefix_mask[i, :]) + ) + else: + # Each token in the non-prefix part of the sequence pays attention to + # one more token than the previous token in the sequence. + self.assertEqual( + number_of_tokens_attended, last_number_of_tokens_attended + 1 + ) + + last_number_of_tokens_attended = number_of_tokens_attended + + def test_make_causal_with_prefix_attn_mask_fails_with_bad_input_mask_shape( + self, + ): + bad_input_mask = jnp.array([[[True]]]) + prefix_mask = jnp.array([[True], [True]]) + with self.assertRaises(ValueError): + transformer_lib.make_causal_with_prefix_attn_mask( + bad_input_mask, prefix_mask + ) + + def test_make_causal_with_prefix_attn_mask_fails_with_bad_prefix_mask_shape( + self, + ): + input_mask = jnp.array([[True], [True]]) + bad_prefix_mask = jnp.array([[[True]]]) + with self.assertRaises(ValueError): + transformer_lib.make_causal_with_prefix_attn_mask( + input_mask, bad_prefix_mask + ) + + if __name__ == '__main__': absltest.main()
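
Reviewer note (not part of the patch): the sketch below shows the behavioural change this diff introduces, exercising the new make_causal_with_prefix_attn_mask next to the existing make_causal_attn_mask on a toy batch. It assumes only that the patch is applied and that the module is importable as gemma.transformer, matching the tutorial's transformer_lib alias; the toy masks are illustrative, and the logical_and/logical_not derivation of the prefix mask stands in for the int-subtraction used in the updated notebook cell (the two are equivalent when target tokens are a subset of non-padded tokens).

# Minimal sketch, assuming this patch is applied and gemma.transformer is importable.
import jax.numpy as jnp
from gemma import transformer as transformer_lib

# One sequence of length 5: two prompt (prefix) tokens, two target tokens,
# and one padding token at the end.
input_mask = jnp.array([[True, True, True, True, False]])     # non-padded tokens
target_mask = jnp.array([[False, False, True, True, False]])  # loss-bearing tokens

# As in the updated notebook cell: the prefix is every non-padded token that
# is not a target token.
prefix_mask = jnp.logical_and(input_mask, jnp.logical_not(target_mask))

causal = transformer_lib.make_causal_attn_mask(input_mask)
prefix_lm = transformer_lib.make_causal_with_prefix_attn_mask(input_mask, prefix_mask)

print(causal.astype(jnp.int32)[0])
# Expected pattern (rows = query positions, columns = key positions):
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 0]]   (the padded position is never attended to as a key)

print(prefix_lm.astype(jnp.int32)[0])
# [[1 1 0 0 0]
#  [1 1 0 0 0]   <- the two prefix tokens attend to each other bidirectionally,
#  [1 1 1 0 0]      as in the right diagram of Figure 3 of the T5 paper;
#  [1 1 1 1 0]      target tokens remain causal.
#  [1 1 1 1 0]]

The printed matrices match the invariants checked by the new tests: padded positions are never attended to as keys, prefix positions are attended to from every query position, and non-prefix rows stay strictly causal.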