Skip to content

Commit

Permalink
rms norm test succeed
Browse files Browse the repository at this point in the history
  • Loading branch information
wusar committed Aug 15, 2023
1 parent 15814d3 commit 804ead9
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions python/unittest/test_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ def test_layernorm_internal(
elif layer_norm_type == "rmsnorm":
gt = (
torch_input
* torch.sqrt(
torch.mean(torch_input**2, dim=-1, keepdim=True) + 1e-5
).numpy()
)
* torch.rsqrt(torch_input.pow(2).mean(-1, keepdim=True) + 1e-5)
).numpy()

# test if the result is correct
max_abs_error = np.max(np.abs(output_tensor_host - gt))
Expand Down

0 comments on commit 804ead9

Please sign in to comment.