Fixes to the Gemma fine-tuning Colab tutorial.
1 - Correct a bug where only the first sequence in a batch was being used in the loss calculation. (As the tutorial ultimately uses a batch size of 1, this went unnoticed, but it is worth correcting, since folks may reuse this piece of code in their own projects; the toy shape check right after this list illustrates the difference.)

2 - Change the attention mask to be causal with prefix instead of just causal, i.e., an attention mask like the right diagram in Figure 3 of https://arxiv.org/pdf/1910.10683 rather than the center diagram. Using the prefix in the attention mask is more appropriate for the fine-tuning task used in the tutorial. (A small sketch after the gemma/transformer.py diff below shows the resulting mask on a toy sequence.)
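As a quick illustration of point 1, here is a minimal shape check with toy dimensions (not taken from the tutorial) showing why the old slicing silently drops everything but the first sequence in the batch:

import jax.numpy as jnp

B, L, V = 4, 8, 16  # toy batch size, sequence length, and vocab size
logits = jnp.zeros((B, L, V))
input_tokens = jnp.zeros((B, L), dtype=jnp.int32)

# Old slicing: indexing with 0 removes the batch dimension, so only the
# first sequence in the batch ever contributes to the loss.
old_logits = logits[0, :-1]            # shape (L-1, V)
old_targets = input_tokens[0, 1:]      # shape (L-1,)

# New slicing: keeps the batch dimension intact.
new_logits = logits[:, :-1, ...]       # shape (B, L-1, V)
new_targets = input_tokens[..., 1:]    # shape (B, L-1)

print(old_logits.shape, new_logits.shape)   # (7, 16) vs. (4, 7, 16)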

PiperOrigin-RevId: 688741676
saugenst authored and The gemma Authors committed Oct 28, 2024
1 parent 3d6e938 commit cc13ebe
Showing 3 changed files with 216 additions and 37 deletions.
76 changes: 44 additions & 32 deletions gemma/colabs/fine_tuning_tutorial.ipynb
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"id": "XpSw-_4EEcoY"
},
@@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"id": "8q5seOhcUBhx"
},
@@ -102,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"id": "9PEefz8wEcoY"
},
@@ -118,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"id": "yWaP_LPoEcoY"
},
@@ -166,7 +166,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "pg8SfQH0EcoY"
@@ -207,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "TpyG5YW1EcoY"
@@ -235,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "L9cjK0uxEcoY"
@@ -313,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xEA-97ioEcoY"
@@ -358,7 +358,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XwFFs2mDEcoY"
@@ -506,7 +506,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "bYeduOaNEcoZ"
@@ -540,7 +540,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "VDlfziQVEcoZ"
@@ -574,7 +574,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "jWr6Sea_EcoZ"
@@ -590,7 +590,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "S6937NTjEcoZ"
@@ -617,7 +617,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "LbJa4S5WEcoZ"
@@ -648,7 +648,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "iEcV0XEEEcoZ"
@@ -688,11 +688,11 @@
" )\n",
"\n",
" # Exclude the last step as it does not appear in the targets.\n",
" logits = logits[0, :-1]\n",
" logits = logits[:, :-1, ...]\n",
"\n",
" # Similarly, the first token cannot be predicteds.\n",
" target_tokens = input_tokens[0, 1:]\n",
" target_mask = input_mask[0, 1:]\n",
" # Similarly, the first token cannot be predicted.\n",
" target_tokens = input_tokens[..., 1:]\n",
" target_mask = input_mask[..., 1:]\n",
"\n",
" # Convert the target labels into one-hot encoded vectors.\n",
" one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n",
@@ -718,20 +718,26 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "cbWfdHf0EcoZ"
},
"outputs": [],
"source": [
"def get_attention_mask_and_positions(example: jax.Array,\n",
" target_mask: jax.Array,\n",
" pad_id : int,\n",
" )-\u003e tuple[jax.Array, jax.Array]:\n",
" \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n",
" pad_mask = example != pad_id\n",
" current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n",
" attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n",
" input_mask = example != pad_id\n",
" current_token_position = transformer_lib.build_positions_from_mask(input_mask)\n",
" prefix_mask = (\n",
" input_mask.astype(jnp.int32) - target_mask.astype(jnp.int32)\n",
" ).astype(jnp.bool)\n",
" attention_mask = transformer_lib.make_causal_with_prefix_attn_mask(\n",
" input_mask, prefix_mask\n",
" )\n",
" return current_token_position, attention_mask"
]
},
@@ -746,7 +752,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "cPSfp7ZUEcoZ"
@@ -774,7 +780,9 @@
" \"\"\"\n",
"\n",
" # Build the position and attention mask vectors.\n",
" positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n",
" positions, attention_mask = get_attention_mask_and_positions(\n",
" example.input_tokens, example.target_mask, pad_id\n",
" )\n",
"\n",
" # Forward and backward passes\n",
" train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,\n",
@@ -801,7 +809,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "yU4oR92YEcoa"
@@ -813,7 +821,9 @@
" pad_id: int,\n",
" example: TrainingInput,\n",
" ):\n",
" positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n",
" positions, attention_mask = get_attention_mask_and_positions(\n",
" example.input_tokens, example.target_mask, pad_id\n",
" )\n",
" val_loss = forward_and_loss_fn(params,\n",
" model=model,\n",
" input_tokens=example.input_tokens,\n",
@@ -834,7 +844,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xT4bAqNLEcoa"
@@ -932,7 +942,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "7SL2VAmVEcoa"
@@ -966,7 +976,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "dQ1oCF10Ecod"
@@ -991,7 +1001,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "S5F3fk22Ecod"
@@ -1008,7 +1018,9 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"private_outputs": true
"last_runtime": {},
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
49 changes: 44 additions & 5 deletions gemma/transformer.py
@@ -337,25 +337,64 @@ def __call__(


def make_causal_attn_mask(
input_mask: jax.Array,
input_mask: jax.Array, # Shape [B, L]
) -> jax.Array:
"""Attention mask in batch mode.
"""Makes a causal attention mask.
I.e., as in the middle diagram of Figure 3 in https://arxiv.org/pdf/1910.10683.
Args:
input_mask: Input mask for the input. True for non-padded tokens only, else
False.
Returns:
Attention mask.
Attention mask of shape [B, L, L] (where B=batch dim and L=sequence dim).
"""
if len(input_mask.shape) != 2:
raise ValueError(
f'Input mask must be 2D (shape [B, L]), but got {input_mask.shape}.'
)
seq_len = input_mask.shape[-1]
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool))
attn_mask = input_mask[..., None, :]
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
# Prefixes can be attended by all tokens
attn_mask *= causal_mask[None, ...]
return attn_mask


def make_causal_with_prefix_attn_mask(
input_mask: jax.Array, # Shape [B, L]
prefix_mask: jax.Array, # Shape [B, L]
) -> jax.Array:
"""Makes a causal with prefix attention mask.
I.e., as in the right diagram of Figure 3 in https://arxiv.org/pdf/1910.10683.
Args:
input_mask: Input mask for the input. True for non-padded tokens only, else
False.
prefix_mask: Input mask for the prefix. True for prefix tokens only, else
False.
Returns:
Attention mask of shape [B, L, L] (where B=batch dim and L=sequence dim).
"""
if len(input_mask.shape) != 2:
raise ValueError(
f'Input mask must be 2D (shape [B, L]), but got {input_mask.shape}.'
)
if len(prefix_mask.shape) != 2:
raise ValueError(
f'Prefix mask must be 2D (shape [B, L]), but got {prefix_mask.shape}.'
)
seq_len = input_mask.shape[-1]
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool))
prefix_mask = jnp.tile(jnp.expand_dims(prefix_mask, axis=1), [1, seq_len, 1])
causal_or_prefix_mask = jnp.logical_or(causal_mask, prefix_mask)
attn_mask = input_mask[..., None, :]
attn_mask *= causal_or_prefix_mask
return attn_mask


def build_positions_from_mask(input_mask: jax.Array) -> jax.Array:
"""Computes the `positions` from the `input_mask`.
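For readers who want to see the new mask concretely, below is a small self-contained sketch. The helper mirrors the logic of the make_causal_with_prefix_attn_mask added in the gemma/transformer.py diff above rather than importing it from the gemma package, and the toy masks (two prefix tokens, two target tokens, one pad token) are illustrative only:

import jax.numpy as jnp

def causal_with_prefix_mask(input_mask, prefix_mask):
  # Mirrors the diffed make_causal_with_prefix_attn_mask above.
  seq_len = input_mask.shape[-1]
  causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
  prefix = jnp.tile(prefix_mask[:, None, :], [1, seq_len, 1])
  return jnp.logical_and(input_mask[:, None, :], jnp.logical_or(causal, prefix))

# Toy sequence of length 5: two prefix (prompt) tokens, two target tokens, one pad token.
input_mask = jnp.array([[1, 1, 1, 1, 0]], dtype=bool)
prefix_mask = jnp.array([[1, 1, 0, 0, 0]], dtype=bool)

print(causal_with_prefix_mask(input_mask, prefix_mask)[0].astype(jnp.int32))
# [[1 1 0 0 0]    <- prefix tokens attend to the whole prefix
#  [1 1 0 0 0]
#  [1 1 1 0 0]    <- target tokens attend causally, plus the prefix
#  [1 1 1 1 0]
#  [1 1 1 1 0]]   <- the pad column is masked out everywhere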
