From 335cb4acc8e92618c504722c940c38ef1d84490a Mon Sep 17 00:00:00 2001
From: Gemma Team
Date: Wed, 10 Jul 2024 13:49:08 -0700
Subject: [PATCH] Simplify ffw tests

PiperOrigin-RevId: 651139640
Change-Id: If038c9cf1383fe2aa7259829509252ab6f4b61d5
---
 gemma/modules_test.py | 47 ++++++++++++++++++---------------------------
 1 file changed, 20 insertions(+), 27 deletions(-)

diff --git a/gemma/modules_test.py b/gemma/modules_test.py
index 36f33e2..fa724c0 100644
--- a/gemma/modules_test.py
+++ b/gemma/modules_test.py
@@ -17,6 +17,7 @@
 import logging
 
 from absl.testing import absltest
+from absl.testing import parameterized
 from gemma import modules
 import jax
 import jax.numpy as jnp
@@ -246,41 +247,33 @@ def test_query_pre_attn_scalar_modifies_output(self):
     )
 
 
-class FeedForwardTest(absltest.TestCase):
+class FeedForwardTest(parameterized.TestCase):
 
-  def test_ffw(self):
-    features = 2
-    hidden_dim = 3
-    batch_size = 2
-    inputs = jnp.arange(1, batch_size + 1)[:, None, None]
-    inputs = jnp.repeat(inputs, features, axis=-1)
-    ffw = modules.FeedForward(features=features, hidden_dim=hidden_dim)
-    params = {
-        'gating_einsum': jnp.ones((batch_size, features, hidden_dim)),
-        'linear': jnp.ones((hidden_dim, features)),
-    }
-
-    outputs = ffw.apply({'params': params}, inputs)
-
-    expected_val = [11.72758674, 47.99916]
-    expected_shape = (2, 1, 2)
-    np.testing.assert_array_almost_equal(outputs[:, 0, 0], expected_val)
-    self.assertEqual(outputs.shape, expected_shape)
-
-  def test_ffw_with_gqa(self):
+  @parameterized.parameters(
+      dict(
+          use_gqa=False,
+      ),
+      dict(
+          use_gqa=True,
+      ),
+  )
+  def test_ffw(self, use_gqa: bool):
     features = 2
     hidden_dim = 3
     batch_size = 2
     inputs = jnp.arange(1, batch_size + 1)[:, None, None]
     inputs = jnp.repeat(inputs, features, axis=-1)
     ffw = modules.FeedForward(
-        features=features, hidden_dim=hidden_dim, use_gqa=True
+        features=features, hidden_dim=hidden_dim, use_gqa=use_gqa
     )
-    params = {
-        # Gating einsum dimensions are different with GQA.
-        'gating_einsum': jnp.ones((batch_size, hidden_dim, features)),
-        'linear': jnp.ones((hidden_dim, features)),
-    }
+
+    params = {'linear': jnp.ones((hidden_dim, features))}
+
+    # Different checkpoints have params saved in different order
+    if use_gqa:
+      params['gating_einsum'] = jnp.ones((batch_size, hidden_dim, features))
+    else:
+      params['gating_einsum'] = jnp.ones((batch_size, features, hidden_dim))
 
     outputs = ffw.apply({'params': params}, inputs)
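
Note on the pattern this patch adopts: absl's parameterized.TestCase expands
each dict(...) passed to @parameterized.parameters into its own generated
test case, so the single test_ffw body above runs once with use_gqa=False and
once with use_gqa=True. A minimal self-contained sketch of the same mechanism;
the DoubleTest class and its values are hypothetical, for illustration only:

    from absl.testing import absltest
    from absl.testing import parameterized


    class DoubleTest(parameterized.TestCase):

      # Each dict becomes one generated test method (test_double0,
      # test_double1), each receiving its dict entries as kwargs.
      @parameterized.parameters(
          dict(x=1, expected=2),
          dict(x=3, expected=6),
      )
      def test_double(self, x: int, expected: int):
        self.assertEqual(x * 2, expected)


    if __name__ == '__main__':
      absltest.main()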
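Note on the use_gqa branch: the merged test builds the same gating_einsum
weights in two layouts, (2, features, hidden_dim) on the default path and the
transposed (2, hidden_dim, features) when use_gqa=True. The leading axis of
size 2 appears to hold the two gating projections; the test reuses batch_size
for it, which also happens to be 2. Below is a hedged sketch of how the two
layouts could be contracted against the same input. The einsum subscripts are
an assumption for illustration, not copied from modules.FeedForward:

    import jax.numpy as jnp

    features, hidden_dim = 2, 3
    x = jnp.ones((1, 4, features))  # (batch, seq, features)

    # Default layout: (2, features, hidden_dim); contract over features.
    w = jnp.ones((2, features, hidden_dim))
    gates = jnp.einsum('...f,kfh->k...h', x, w)

    # Transposed layout (as with use_gqa=True): (2, hidden_dim, features).
    # Swapping the subscripts recovers the same contraction.
    w_t = jnp.ones((2, hidden_dim, features))
    gates_t = jnp.einsum('...f,khf->k...h', x, w_t)

    # Both layouts yield gate activations of shape (2, batch, seq, hidden_dim).
    assert gates.shape == gates_t.shape == (2, 1, 4, hidden_dim)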