Move post-attention norm prior to the residual connection.
PiperOrigin-RevId: 649232943
Change-Id: Ia23f177f1342ee668322221a55dda816404c094c
texasmichelle committed Jul 30, 2024
1 parent 87c17fc commit 205e096
Showing 2 changed files with 17 additions and 17 deletions.
23 changes: 10 additions & 13 deletions gemma/modules.py
@@ -233,16 +233,16 @@ def setup(self):
         attn_logits_soft_cap=self.attn_logits_soft_cap,
         sliding_window_size=self.sliding_window_size,
     )
+    self.post_attn_norm = None
+    if self.use_post_attn_norm:
+      self.post_attn_norm = layers.RMSNorm()
+
     self.pre_ffw_norm = layers.RMSNorm()
     self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
     self.post_ffw_norm = None
     if self.use_post_ffw_norm:
       self.post_ffw_norm = layers.RMSNorm()
-
-    self.post_attn_norm = None
-    if self.use_post_attn_norm:
-      self.post_attn_norm = layers.RMSNorm()
 
   def __call__(
       self,
       x: jax.Array,
@@ -257,15 +257,12 @@ def __call__(
         cache,
         attn_mask,
     )
-    attn_output += x
-    residual = attn_output
-    attn_output = self.pre_ffw_norm(attn_output)
-
-    if self.use_post_attn_norm:
+    if self.post_attn_norm is not None:
       attn_output = self.post_attn_norm(attn_output)
-
-    outputs = self.mlp(attn_output)
-    if self.use_post_ffw_norm:
+    attn_output += x
+    outputs = self.pre_ffw_norm(attn_output)
+    outputs = self.mlp(outputs)
+    if self.post_ffw_norm is not None:
       outputs = self.post_ffw_norm(outputs)
-    outputs = residual + outputs
+    outputs += attn_output
     return cache, outputs
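
For readers skimming the diff above: before this change the post-attention RMSNorm ran after the residual connection and the pre-FFW norm, so it rescaled the already-merged activations; after it, the norm is applied to the attention branch alone and the skip path is added untouched. The sketch below is not the repository code: rms_norm, mlp_input_before, and mlp_input_after are hypothetical stand-ins for layers.RMSNorm (with a unit scale) and the block tail, meant only to show that the two orderings feed different tensors to the MLP.

import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
  # Stand-in for layers.RMSNorm with a unit scale; illustration only.
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def mlp_input_before(x, attn_output):
  # Removed ordering: residual first, then pre_ffw_norm, then the
  # post-attention norm on top of the already-merged activations.
  h = rms_norm(attn_output + x)  # pre_ffw_norm over (attention + residual)
  return rms_norm(h)             # post_attn_norm applied after the residual


def mlp_input_after(x, attn_output):
  # Added ordering: post-attention norm on the attention output alone,
  # then the residual connection, then pre_ffw_norm.
  return rms_norm(rms_norm(attn_output) + x)


x = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8))
attn_output = jax.random.normal(jax.random.PRNGKey(1), (1, 4, 8))
print(jnp.allclose(mlp_input_before(x, attn_output),
                   mlp_input_after(x, attn_output)))  # generally False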
11 changes: 7 additions & 4 deletions gemma/modules_test.py
@@ -204,7 +204,7 @@ def test_block(self):
     self.assertEqual(new_cache['k'].shape, expected_cache_shape)
     self.assertEqual(outputs.shape, expected_output_shape)
 
-  def test_post_attention_norm_preserves_output(self):
+  def test_post_attention_norm_modifies_output(self):
     num_heads = 1
     embed_dim = 1
     head_dim = 2
@@ -255,9 +255,12 @@ def test_post_attention_norm_preserves_output(self):
     normed_output, unnormed_output = all_outputs  # pylint: disable=unbalanced-tuple-unpacking
     logging.info('normed_output: %s', normed_output)
     logging.info('unnormed_output: %s', unnormed_output)
-    # TODO(b/350763078): Fix bug in the attention implementation. Normed and
-    # unnormed outputs should not be equal.
-    np.testing.assert_array_equal(normed_output, unnormed_output)
+    np.testing.assert_raises(
+        AssertionError,
+        np.testing.assert_array_equal,
+        normed_output,
+        unnormed_output,
+    )
 
   def test_post_ffw_norm_preserves_output(self):
     num_heads = 1
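The renamed test now requires the post-attention norm to change the block output instead of asserting that it is a no-op. A minimal, self-contained illustration of the assert_raises pattern it relies on (the arrays below are made-up stand-ins, not values produced by the test): assert_raises passes exactly when the wrapped equality check fails.

import numpy as np

# Hypothetical stand-ins for block outputs with and without post_attn_norm.
unnormed = np.array([[1.0, 2.0, 3.0]])
normed = unnormed / np.sqrt(np.mean(np.square(unnormed), axis=-1, keepdims=True))

# Passes because assert_array_equal raises AssertionError: at least one
# element differs once the (stand-in) normalization is applied.
np.testing.assert_raises(
    AssertionError,
    np.testing.assert_array_equal,
    normed,
    unnormed,
)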
