From 804ead975e573a35caf643eeb5e8e7d70f0b99ec Mon Sep 17 00:00:00 2001 From: Lifanwu Date: Tue, 15 Aug 2023 16:53:23 +0000 Subject: [PATCH] rms norm test succeed --- python/unittest/test_layernorm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/unittest/test_layernorm.py b/python/unittest/test_layernorm.py index b339f5601..713cbde6a 100644 --- a/python/unittest/test_layernorm.py +++ b/python/unittest/test_layernorm.py @@ -46,10 +46,8 @@ def test_layernorm_internal( elif layer_norm_type == "rmsnorm": gt = ( torch_input - * torch.sqrt( - torch.mean(torch_input**2, dim=-1, keepdim=True) + 1e-5 - ).numpy() - ) + * torch.rsqrt(torch_input.pow(2).mean(-1, keepdim=True) + 1e-5) + ).numpy() # test if the result is correct max_abs_error = np.max(np.abs(output_tensor_host - gt))