Simplify ffw tests
PiperOrigin-RevId: 651139640
Change-Id: If038c9cf1383fe2aa7259829509252ab6f4b61d5
Gemma Team authored and texasmichelle committed Jul 30, 2024
1 parent e6e2fee · commit 335cb4a
Showing 1 changed file with 20 additions and 27 deletions.
gemma/modules_test.py: 47 changes (20 additions & 27 deletions)
@@ -17,6 +17,7 @@
 import logging
 
 from absl.testing import absltest
+from absl.testing import parameterized
 from gemma import modules
 import jax
 import jax.numpy as jnp
@@ -246,41 +247,33 @@ def test_query_pre_attn_scalar_modifies_output(self):
     )
 
 
-class FeedForwardTest(absltest.TestCase):
+class FeedForwardTest(parameterized.TestCase):
 
-  def test_ffw(self):
-    features = 2
-    hidden_dim = 3
-    batch_size = 2
-    inputs = jnp.arange(1, batch_size + 1)[:, None, None]
-    inputs = jnp.repeat(inputs, features, axis=-1)
-    ffw = modules.FeedForward(features=features, hidden_dim=hidden_dim)
-    params = {
-        'gating_einsum': jnp.ones((batch_size, features, hidden_dim)),
-        'linear': jnp.ones((hidden_dim, features)),
-    }
-
-    outputs = ffw.apply({'params': params}, inputs)
-
-    expected_val = [11.72758674, 47.99916]
-    expected_shape = (2, 1, 2)
-    np.testing.assert_array_almost_equal(outputs[:, 0, 0], expected_val)
-    self.assertEqual(outputs.shape, expected_shape)
-
-  def test_ffw_with_gqa(self):
+  @parameterized.parameters(
+      dict(
+          use_gqa=False,
+      ),
+      dict(
+          use_gqa=True,
+      ),
+  )
+  def test_ffw(self, use_gqa: bool):
     features = 2
     hidden_dim = 3
     batch_size = 2
     inputs = jnp.arange(1, batch_size + 1)[:, None, None]
     inputs = jnp.repeat(inputs, features, axis=-1)
     ffw = modules.FeedForward(
-        features=features, hidden_dim=hidden_dim, use_gqa=True
+        features=features, hidden_dim=hidden_dim, use_gqa=use_gqa
     )
-    params = {
-        # Gating einsum dimensions are different with GQA.
-        'gating_einsum': jnp.ones((batch_size, hidden_dim, features)),
-        'linear': jnp.ones((hidden_dim, features)),
-    }
+
+    params = {'linear': jnp.ones((hidden_dim, features))}
+
+    # Different checkpoints have params saved in different order
+    if use_gqa:
+      params['gating_einsum'] = jnp.ones((batch_size, hidden_dim, features))
+    else:
+      params['gating_einsum'] = jnp.ones((batch_size, features, hidden_dim))
 
     outputs = ffw.apply({'params': params}, inputs)
 
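The change collapses two near-duplicate tests into one by switching FeedForwardTest from absltest.TestCase to parameterized.TestCase: each dict passed to @parameterized.parameters() becomes its own generated test case, with the dict's keys supplied to the test body as keyword arguments. A minimal self-contained sketch of the pattern (the class and method names here are illustrative, not part of the commit):

from absl.testing import absltest
from absl.testing import parameterized


class ParameterizedPatternTest(parameterized.TestCase):

  @parameterized.parameters(
      dict(use_gqa=False),
      dict(use_gqa=True),
  )
  def test_branch(self, use_gqa: bool):
    # One test body runs once per dict above; absltest reports each
    # expansion as its own case, so both branches get covered.
    self.assertIsInstance(use_gqa, bool)


if __name__ == '__main__':
  absltest.main()

Running this file directly executes two generated cases, one per parameter dict, which is how the merged test_ffw covers both the GQA and non-GQA paths with a single body.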

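The use_gqa branch exists because the two checkpoint families store gating_einsum with its trailing axes swapped: (2, hidden_dim, features) with GQA versus (2, features, hidden_dim) without (the test happens to reuse batch_size = 2 for the leading axis, which presumably holds the two gating projections). A rough sketch of how a GeGLU-style feed-forward could accept either layout — an illustration under those assumptions, with ffw_forward a hypothetical stand-in rather than the real modules.FeedForward:

import jax
import jax.numpy as jnp


def ffw_forward(x, params, use_gqa: bool):
  # Hypothetical gated feed-forward; not the actual modules.FeedForward.
  w = params['gating_einsum']
  if use_gqa:
    # Assumed GQA layout (2, hidden_dim, features): swap the trailing
    # axes to recover (2, features, hidden_dim) before projecting.
    w = jnp.swapaxes(w, -1, -2)
  gate = jax.nn.gelu(x @ w[0])           # (batch, seq, hidden_dim)
  up = x @ w[1]                          # (batch, seq, hidden_dim)
  return (gate * up) @ params['linear']  # back down to (batch, seq, features)


features, hidden_dim, batch_size = 2, 3, 2
x = jnp.ones((batch_size, 1, features))
params = {
    'gating_einsum': jnp.ones((2, hidden_dim, features)),  # GQA-style layout
    'linear': jnp.ones((hidden_dim, features)),
}
print(ffw_forward(x, params, use_gqa=True).shape)  # (2, 1, 2)

With use_gqa=False the swap is skipped and a (2, features, hidden_dim) weight is consumed as stored, mirroring the two parameter shapes the test constructs.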