-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: RMSNorm #59
Merged
Merged
Changes from all commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
23f7a96
update cpp ext of rms norm
zhangzefeng92 b0aa03d
Update test_rms_lightlm.py
zhangzefeng92 0517125
Update test_rms_lightlm.py
zhangzefeng92 d675c9f
Update extensions.cpp
zhangzefeng92 3f100c0
modify extensions.cpp
zhangzefeng92 fce081f
fix python lint
zhangzefeng92 7e59912
fix python lint
zhangzefeng92 bd74cb0
fix python lint
zhangzefeng92 5027589
fix python lint
zhangzefeng92 7fc26c7
fix rms norm
zhangzefeng92 136e8e1
modify rms norm
zhangzefeng92 9277e0a
update cpp ext of rms norm
zhangzefeng92 1be8963
Update test_rms_lightlm.py
zhangzefeng92 6cd9db6
Update test_rms_lightlm.py
zhangzefeng92 fd486a6
Update extensions.cpp
zhangzefeng92 d54413b
modify extensions.cpp
zhangzefeng92 6dad1ed
fix python lint
zhangzefeng92 0a8a28d
fix python lint
zhangzefeng92 4ef9d48
fix python lint
zhangzefeng92 241492b
fix python lint
zhangzefeng92 5cafafe
fix rms norm
zhangzefeng92 3d9a805
modify rms norm
zhangzefeng92 33d3618
Merge branch 'main' into zzf/fix_rmsnorm
zhangzefeng92 28274a0
Merge branch 'zzf/fix_rmsnorm' of https://github.com/DeepLink-org/Dee…
zhangzefeng92 27da2ec
modify rms norm
zhangzefeng92 3f894a2
modify rms norm
zhangzefeng92 a5590c8
Merge branch 'main' into zzf/fix_rmsnorm
yangbofun 52bc927
lint
yangbofun 5ac11ca
delete the duplicated
yangbofun 4a31673
delete
yangbofun 8ab1be0
Update __init__.py
zhangzefeng92 c314b13
modify test
yangbofun 645c8da
Merge branch 'zzf/fix_rmsnorm' of https://github.com/DeepLink-org/Dee…
yangbofun dc674ca
modify
yangbofun 7ce2b22
modify rotary_embeding
yangbofun 03d5992
modify rotary_embeding
yangbofun 92bcf46
modify rotary_embeding
yangbofun b0965fb
lint
yangbofun e979201
fix
yangbofun a3cd9da
fix
yangbofun ca0b33c
modify mha
yangbofun 6d798aa
rename rotary_embedding
yangbofun 3380cfb
lint
yangbofun File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .deeplink import rms_norm_out, rms_norm, rms_norm_backward_out, rms_norm_backward | ||
|
||
|
||
__all__ = ["rms_norm_out", "rms_norm", "rms_norm_backward_out", "rms_norm_backward"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import torch | ||
import deeplink_ext.cpp_extensions as cpp_ext | ||
|
||
|
||
def rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps): | ||
if None == normalized_shape: | ||
cpp_ext.rms_norm(output, inv_rms, input, weight.shape, weight, bias, eps) | ||
else: | ||
cpp_ext.rms_norm(output, inv_rms, input, normalized_shape, weight, bias, eps) | ||
|
||
|
||
def rms_norm(input, normalized_shape, weight, bias, eps): | ||
output = torch.empty_like(input) | ||
inv_rms_shape = list(input.shape[:-1]) + [1] | ||
inv_rms = torch.empty(inv_rms_shape, dtype=input.dtype, device=input.device) | ||
rms_norm_out(output, inv_rms, input, normalized_shape, weight, bias, eps) | ||
|
||
return [output, inv_rms] | ||
|
||
|
||
def rms_norm_backward_out( | ||
grad_input, | ||
grad_weight, | ||
grad_bias, | ||
grad_output, | ||
input, | ||
weight, | ||
bias, | ||
inv_rms, | ||
normalized_shape, | ||
eps, | ||
): | ||
if None == normalized_shape: | ||
cpp_ext.rms_norm_backward( | ||
grad_input, | ||
grad_weight, | ||
grad_bias, | ||
grad_output, | ||
input, | ||
weight, | ||
bias, | ||
inv_rms, | ||
weight.shape, | ||
eps, | ||
) | ||
else: | ||
cpp_ext.rms_norm_backward( | ||
grad_input, | ||
grad_weight, | ||
grad_bias, | ||
grad_output, | ||
input, | ||
weight, | ||
bias, | ||
inv_rms, | ||
normalized_shape, | ||
eps, | ||
) | ||
|
||
|
||
def rms_norm_backward(input, grad_output, inv_rms, normalized_shape, weight, bias, eps): | ||
grad_input = torch.empty_like(input) | ||
grad_weight = torch.empty_like(weight) | ||
grad_bias = torch.empty_like(bias) | ||
rms_norm_backward_out( | ||
grad_input, | ||
grad_weight, | ||
grad_bias, | ||
grad_output, | ||
input, | ||
weight, | ||
bias, | ||
inv_rms, | ||
normalized_shape, | ||
eps, | ||
) | ||
|
||
return [grad_input, grad_weight, grad_bias] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,40 @@ | ||
# Copyright (c) 2024, DeepLink. | ||
|
||
from . import mha, rms_norm, rotary | ||
from . import mha | ||
|
||
__all__ = ["mha", "rms_norm", "rotary"] | ||
|
||
_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." | ||
|
||
|
||
try: | ||
from .rms_norm import RMSNorm, RMSNormWithNormalizedShape | ||
except: | ||
print( | ||
_not_impl.format(op_name="RMSNorm or RMSNormWithNormalizedShape"), | ||
) | ||
from .rms_norm_fallback import ( | ||
RMSNorm, | ||
RMSNormWithNormalizedShape, | ||
) | ||
|
||
|
||
try: | ||
from .rotary_embedding import apply_rotary | ||
except: | ||
print(_not_impl.format(op_name="apply_rotary")) | ||
from .rotary_embeddinig_fallback import apply_rotary | ||
|
||
|
||
try: | ||
from .mha import SelfAttention, CrossAttention | ||
except Exception as e: | ||
print(_not_impl.format(op_name="mha")) | ||
from .mha_fallback import SelfAttention, CrossAttention | ||
|
||
__all__ = [ | ||
"SelfAttention", | ||
"CrossAttention", | ||
"RMSNorm", | ||
"RMSNormWithNormalizedShape", | ||
"apply_rotary", | ||
] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个签名需要修改
那个头文件可以完全删除