Commit

refactor python code
refactor python code by ops
Wrench-Git committed Nov 5, 2024
1 parent 9eb52a2 commit 0b6718c
Showing 43 changed files with 187 additions and 1,146 deletions.
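The thrust of the change, visible in the diffs below: deeplink_ext/easyllm_ops/ and deeplink_ext/internevo_ops/ stop carrying their own op modules and torch fallbacks and instead re-export the shared implementations from the deeplink_ext.ops package; only the AdamW fallback to torch.optim stays in the package __init__ files. A minimal sketch of the practical effect, assuming deeplink_ext imports cleanly on the target device (the module paths are taken from the diffs; the identity check itself is only illustrative):

# Sketch: after this commit both entry-point packages re-export the same shared
# implementation, so the imported names resolve to identical function objects.
from deeplink_ext.ops.flash_attention import flash_attn_func as shared_func
from deeplink_ext.easyllm_ops import flash_attn_func as easyllm_func
from deeplink_ext.internevo_ops import flash_attn_func as internevo_func

assert easyllm_func is shared_func
assert internevo_func is shared_func
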
40 changes: 11 additions & 29 deletions deeplink_ext/easyllm_ops/__init__.py
@@ -3,40 +3,22 @@
 _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
 
 try:
-    from .adamw import AdamW
+    from deeplink_ext.ops.adamw import AdamW
 except Exception as e:
     print(_not_impl.format(op_name="adamw"))
     from torch.optim import AdamW
 
-try:
-    from .flash_attention import (
-        flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func,
-        flash_attn_func,
-        flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func,
-    )
-except Exception as e:
-    print(_not_impl.format(op_name="flash attention"))
-    from .flash_attention_fallback import (
-        flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
-        flash_attn_func_torch as flash_attn_func,
-        flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func_torch as flash_attn_varlen_func,
-    )
-
-try:
-    from .rms_norm import rms_norm
-except:
-    print(
-        _not_impl.format(op_name="RMSNorm"),
-    )
-    from .rms_norm_fallback import rms_norm_torch as rms_norm
+from deeplink_ext.ops.flash_attention import (
+    flash_attn_qkvpacked_func,
+    flash_attn_kvpacked_func,
+    flash_attn_func,
+    flash_attn_varlen_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_func,
+)
 
-from .bert_padding import pad_input, unpad_input, index_first_axis
+from deeplink_ext.ops.rms_norm import rms_norm
+from deeplink_ext.ops.bert_padding import pad_input, unpad_input, index_first_axis
 
 __all__ = [
     "AdamW",
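For reference, a hedged usage sketch of the re-exported flash attention entry point. The commit only changes where flash_attn_func is imported from, not its interface; the call convention below (q/k/v shaped (batch, seqlen, num_heads, head_dim), half precision, dropout_p and causal keyword arguments) is assumed to mirror the upstream flash-attn API:

import torch

from deeplink_ext.easyllm_ops import flash_attn_func  # re-exported from deeplink_ext.ops.flash_attention

device = "cuda"  # assumption: whatever accelerator the diopi/DIPU backend exposes
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # expected to match q: (2, 128, 8, 64)
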
5 changes: 0 additions & 5 deletions deeplink_ext/easyllm_ops/adamw.py

This file was deleted.

19 changes: 0 additions & 19 deletions deeplink_ext/easyllm_ops/flash_attention.py

This file was deleted.

13 changes: 0 additions & 13 deletions deeplink_ext/easyllm_ops/rms_norm_fallback.py

This file was deleted.

45 changes: 11 additions & 34 deletions deeplink_ext/internevo_ops/__init__.py
@@ -1,46 +1,23 @@
 # Copyright (c) 2024, DeepLink.
 
 _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
 
 try:
-    from .adamw import AdamW
+    from deeplink_ext.ops.adamw import AdamW
 except Exception as e:
     print(_not_impl.format(op_name="adamw"))
     from torch.optim import AdamW
 
-try:
-    from .flash_attention import (
-        flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func,
-        flash_attn_func,
-        flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func,
-    )
-except Exception as e:
-    print(_not_impl.format(op_name="flash attention"))
-    from .flash_attention_fallback import (
-        flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
-        flash_attn_func_torch as flash_attn_func,
-        flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func_torch as flash_attn_varlen_func,
-    )
+from deeplink_ext.ops.flash_attention import (
+    flash_attn_qkvpacked_func,
+    flash_attn_kvpacked_func,
+    flash_attn_func,
+    flash_attn_varlen_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_func,
+)
 
-try:
-    from .rms_norm import MixedFusedRMSNorm
-except:
-    print(
-        _not_impl.format(op_name="RMSNorm"),
-    )
-    from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
+from deeplink_ext.ops.rms_norm import MixedFusedRMSNorm
 
-try:
-    from .rotary_embedding import ApplyRotaryEmb
-except:
-    print(_not_impl.format(op_name="rotary embedding"))
-    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
+from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb
 
 __all__ = [
     "AdamW",
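A short sketch of the AdamW path, the one import that keeps its try/except: if the diopi-backed optimizer fails to import, the package falls back to torch.optim.AdamW under the same name, so caller code looks the same either way. Standard torch optimizer arguments are assumed:

import torch

from deeplink_ext.internevo_ops import AdamW  # diopi AdamW, or torch.optim.AdamW if the fallback fired

model = torch.nn.Linear(32, 32)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

loss = model(torch.randn(8, 32)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
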

