Clean up positional embeddings.
PiperOrigin-RevId: 686236709
The gemma Authors committed Oct 28, 2024
1 parent 54718dc commit 3d6e938
Showing 3 changed files with 57 additions and 35 deletions.
2 changes: 0 additions & 2 deletions gemma/modules.py
@@ -153,13 +153,11 @@ def __call__(
     query_proj = positional_embeddings.apply_rope(
         query_proj,
         segment_pos,
-        head_dim=self.head_dim,
     )
     query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
-        head_dim=self.head_dim,
     )
 
     # Cache is left aligned.
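With this change, apply_rope infers head_dim from the last axis of its input, so call sites pass only the projections and positions. A minimal usage sketch of the new convention (assuming the gemma package is importable; the shapes follow the docstrings added below):

import jax.numpy as jnp

from gemma import positional_embeddings

# Hypothetical shapes for illustration: B=2, L=8, N=4, H=16.
query_proj = jnp.ones((2, 8, 4, 16))            # [B, L, N, H]
segment_pos = jnp.tile(jnp.arange(8), (2, 1))   # [B, L] token positions
rotated = positional_embeddings.apply_rope(query_proj, segment_pos)
assert rotated.shape == (2, 8, 4, 16)  # shape is preserved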
55 changes: 41 additions & 14 deletions gemma/positional_embeddings.py
@@ -35,34 +35,61 @@
 
 
 def add_positional_embedding(
-    input_embedding: jax.Array,
-    position: int,
+    inputs: jax.Array,
+    positions: jax.Array,
     max_wavelength: int = _MAX_WAVELENGTH,
 ) -> jax.Array:
-  """Adds positional embeddings to input embeddings."""
-  embed_dim = input_embedding.shape[-1]
-  num_timescales = embed_dim // 2
+  """Adds positional embeddings to inputs.
+
+  Let B denote batch size, L denote sequence length, N denote number of heads,
+  and H denote head dimension. Note that H must be divisible by 2.
+
+  Args:
+    inputs: Array of shape [B, L, N, H].
+    positions: Array of shape [B, L].
+    max_wavelength: The maximum wavelength.
+
+  Returns:
+    Array of shape [B, L, N, H].
+  """
+  head_dim = inputs.shape[-1]
+  num_timescales = head_dim // 2
   log_timescale_increment = jnp.log(float(max_wavelength)) / jnp.maximum(
       jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1
   )
   inv_timescales = jnp.exp(
       jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
   )
-  scaled_time = position * inv_timescales
-  signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)])
-  signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]])
+  scaled_time = (
+      positions[..., jnp.newaxis] * inv_timescales[jnp.newaxis, jnp.newaxis, :]
+  )
+  scaled_time = scaled_time[..., jnp.newaxis, :]
+  signal = jnp.concatenate(
+      [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1
+  )
   position_embedding = signal.astype(jnp.float32)
 
-  return input_embedding + position_embedding
+  return inputs + position_embedding
 
 
 def apply_rope(
-    inputs: jax.Array,  # [B, L]
-    positions: jax.Array,  # [B, L]
-    head_dim: int,
+    inputs: jax.Array,
+    positions: jax.Array,
     max_wavelength: int = _MAX_WAVELENGTH,
 ) -> jax.Array:
-  """Applies RoPE."""
+  """Applies RoPE.
+
+  Let B denote batch size, L denote sequence length, N denote number of heads,
+  and H denote head dimension. Note that H must be divisible by 2.
+
+  Args:
+    inputs: Array of shape [B, L, N, H].
+    positions: Array of shape [B, L].
+    max_wavelength: The maximum wavelength.
+
+  Returns:
+    Array of shape [B, L, N, H].
+  """
+  head_dim = inputs.shape[-1]
   fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
   timescale = max_wavelength**fraction
 
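The rest of apply_rope is collapsed in the diff above. For reference, here is a self-contained sketch of standard split-halves RoPE that is consistent with the fraction/timescale setup shown and reproduces the expected values in the updated test below. The name rope_sketch and the 10_000 default wavelength are assumptions for illustration, not necessarily the commit's exact code:

import jax
import jax.numpy as jnp


def rope_sketch(
    inputs: jax.Array,     # [B, L, N, H], H divisible by 2
    positions: jax.Array,  # [B, L]
    max_wavelength: int = 10_000,  # assumed default; the repo uses _MAX_WAVELENGTH
) -> jax.Array:
  head_dim = inputs.shape[-1]
  fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
  timescale = max_wavelength**fraction  # [H/2] geometric frequencies

  # Rotation angle per (batch, position, frequency), broadcast over heads.
  sinusoid_inp = positions[..., jnp.newaxis] / timescale  # [B, L, H/2]
  sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]        # [B, L, 1, H/2]
  sin, cos = jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)

  # Rotate each (first_half[i], second_half[i]) pair by its angle.
  first_half, second_half = jnp.split(inputs, 2, axis=-1)
  first_part = first_half * cos - second_half * sin
  second_part = second_half * cos + first_half * sin
  return jnp.concatenate([first_part, second_part], axis=-1)

With positions=[[1], [0]], max_wavelength=100, and all-ones inputs of shape (2, 1, 2, 4), this yields the [-0.30116868, 0.89517075, 1.3817732, 1.0948375] rows at position 1 and the unchanged all-ones rows at position 0, matching the test expectations below.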
35 changes: 16 additions & 19 deletions gemma/positional_embeddings_test.py
@@ -39,50 +39,47 @@ class PositionalEmbeddingsTest(parameterized.TestCase):
 
   @parameterized.parameters(
       dict(
-          input_embedding_shape=(2, 1, 1, 5),
-          position=3,
+          input_embedding_shape=(2, 1, 1, 6),
+          positions=[[1], [0]],
           max_wavelength=100,
-          expected=[[[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]],
-                    [[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]]]
+          expected=[
+              [[[1.841471, 1.099833, 1.01, 1.540302, 1.995004, 1.99995]]],
+              [[[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]]],
+          ],
       )
   )
-  def test_adds_positional_embeddings(
-      self, input_embedding_shape, position, max_wavelength, expected
+  def test_add_positional_embeddings(
+      self, input_embedding_shape, positions, max_wavelength, expected
   ):
     outputs = positional_embeddings.add_positional_embedding(
-        jnp.ones(input_embedding_shape), position, max_wavelength
+        jnp.ones(input_embedding_shape), jnp.array(positions), max_wavelength
     )
     np.testing.assert_array_almost_equal(outputs, jnp.array(expected))
 
   @parameterized.parameters(
       dict(
           input_embedding_shape=(2, 1, 2, 4),
-          position=3,
-          head_dim=4,
+          positions=[[1], [0]],
           max_wavelength=100,
           expected=[
               [[
-                  [-1.1311126, 0.6598157, -0.8488725, 1.2508571],
-                  [-1.1311126, 0.6598157, -0.8488725, 1.2508571],
-              ]],
-              [[
-                  [-1.1311126, 0.6598157, -0.8488725, 1.2508571],
-                  [-1.1311126, 0.6598157, -0.8488725, 1.2508571],
+                  [-0.30116868, 0.89517075, 1.3817732, 1.0948375],
+                  [-0.30116868, 0.89517075, 1.3817732, 1.0948375],
               ]],
+              [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]],
           ],
       )
   )
   def test_rope_positional_embeddings(
-      self, input_embedding_shape, position, head_dim, max_wavelength, expected
+      self, input_embedding_shape, positions, max_wavelength, expected
   ):
     outputs = positional_embeddings.apply_rope(
         jnp.ones(input_embedding_shape),
-        jnp.array([[position]]),
-        head_dim,
+        jnp.array(positions),
         max_wavelength,
     )
     np.testing.assert_array_almost_equal(outputs, jnp.array(expected))
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
   absltest.main()
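As a sanity check on the new add_positional_embedding expectations, the position-1 row can be reproduced by hand (a worked example, not part of the commit). With head_dim = 6 and max_wavelength = 100, num_timescales = 3 and log_timescale_increment = log(100) / 2, so:

import numpy as np

inv_timescales = np.exp(-np.arange(3) * np.log(100.0) / 2)  # [1.0, 0.1, 0.01]
scaled_time = 1.0 * inv_timescales  # position = 1
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)])
print(1.0 + signal)
# -> [1.841471 1.099833 1.01 1.540302 1.995004 1.99995]

At position 0 the sines are 0 and the cosines are 1, which gives the [1, 1, 1, 2, 2, 2] row in the second batch entry.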
