handle patch failure -> fallback & add Copyright
lljbash committed Jan 24, 2024
1 parent c0df289 commit b25f0e0
Showing 21 changed files with 80 additions and 7 deletions.
2 changes: 2 additions & 0 deletions deeplink_ext/__init__.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2024, DeepLink.
+
 def _init():
     # deeplink_ext is developed based on dipu
     # so we explicitly import torch_dipu to guarantees that torch is patched by dipu
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/__init__.py
@@ -1 +1,3 @@
+# Copyright (c) 2024, DeepLink.
+
 from . import mha, rms_norm, rotary
11 changes: 10 additions & 1 deletion deeplink_ext/internlm_ops/mha/__init__.py
@@ -1,4 +1,13 @@
 # Copyright (c) 2023, DeepLink.
 
-from .mha import DeepLinkSelfAttention, DeepLinkCrossAttention
+try:
+    from .mha import DeepLinkSelfAttention, DeepLinkCrossAttention
+except Exception as e:
+    print(
+        "[deeplink_ext] mha is not implemented in diopi. Falling back to the slower implementation."
+    )
+    from .fallback import (
+        SelfAttention as DeepLinkSelfAttention,
+        CrossAttention as DeepLinkCrossAttention,
+    )
 from . import fallback
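This hunk is the core of the commit: if importing the DIOPI-backed mha module fails (for example, because one of the new import-time asserts below fires), the package rebinds the same public names to the pure-PyTorch fallback, so downstream code keeps importing DeepLinkSelfAttention and DeepLinkCrossAttention unchanged. A minimal runnable sketch of the idiom, using ujson/json purely as illustrative stand-ins for the fast and slow implementations:

# Fallback-import sketch (illustrative stand-ins, not deeplink_ext code):
# prefer the fast implementation, rebind the same name to a slower one if
# the fast import fails.
try:
    import ujson as json  # fast path; may not be installed
except ImportError:
    print("[sketch] ujson unavailable, falling back to stdlib json")
    import json

print(json.dumps({"fallback": "works"}))  # callers use `json` either way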
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/mha/fallback/__init__.py
@@ -1 +1,3 @@
+# Copyright (c) 2024, DeepLink.
+
 from .fallback import SelfAttention, CrossAttention
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/mha/mha_func.py
@@ -3,6 +3,8 @@
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")
+
 
 class DeepLinkMultiHeadAttentionFunc(torch.autograd.Function):
     @staticmethod
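The assert added here (and to each of the mha_*_func.py files below) is what arms the fallback: it runs at import time, so a compiled extension lacking mha_fwd/mha_bwd raises AssertionError inside `from .mha import ...`, which the try/except in mha/__init__.py above then catches. A standalone sketch of this capability-guard behavior (names are illustrative, not the repo's):

import types

ext = types.SimpleNamespace()  # stand-in for deeplink_ext.cpp_extensions, no kernels

def import_fast_mha():
    # Mirrors `assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")`:
    # fail fast at import time if the extension lacks the needed symbols.
    assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")
    return "fast mha"

try:
    impl = import_fast_mha()
except AssertionError:
    impl = "fallback mha"  # __init__.py rebinds the public names like this

print(impl)  # -> fallback mha, since ext exposes neither symbol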
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py
@@ -3,6 +3,8 @@
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")
+
 
 class DeepLinkMultiHeadAttentionKVPackedFunc(torch.autograd.Function):
     @staticmethod
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py
@@ -3,6 +3,8 @@
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")
+
 
 class DeepLinkMultiHeadAttentionQKVPackedFunc(torch.autograd.Function):
     @staticmethod
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/mha/mha_varlen_func.py
@@ -3,6 +3,8 @@
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")
+
 
 class DeepLinkMultiHeadAttentionVarLenFunc(torch.autograd.Function):
     @staticmethod
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py
@@ -3,6 +3,8 @@
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")
+
 
 class DeepLinkMultiHeadAttentionVarLenKVPackedFunc(torch.autograd.Function):
     @staticmethod
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py
@@ -3,6 +3,8 @@
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd")
+
 
 class DeepLinkMultiHeadAttentionVarLenQKVPackedFunc(torch.autograd.Function):
     @staticmethod
10 changes: 9 additions & 1 deletion deeplink_ext/internlm_ops/rms_norm/__init__.py
@@ -1,2 +1,10 @@
-from .deeplink import DeepLinkRMSNorm, DeepLinkRMSNormWithNormalizedShape
+# Copyright (c) 2024, DeepLink.
+
+try:
+    from .deeplink import DeepLinkRMSNorm, DeepLinkRMSNormWithNormalizedShape
+except:
+    print(
+        "[deeplink_ext] rms_norm is not implemented in diopi. Falling back to the slower implementation."
+    )
+    from .fallback import RMSNorm as DeepLinkRMSNorm
 from . import fallback
4 changes: 4 additions & 0 deletions deeplink_ext/internlm_ops/rms_norm/deeplink.py
@@ -1,6 +1,10 @@
+# Copyright (c) 2024, DeepLink.
+
 import torch
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "rms_norm")
+
 
 # Define the custom autograd function
 class _DeepLinkRMSNormFunction(torch.autograd.Function):
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py
@@ -1 +1,3 @@
+# Copyright (c) 2024, DeepLink.
+
 from .fallback import RMSNorm
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/rms_norm/fallback/fallback.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2024, DeepLink.
+
 import torch
 
 
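The fallback body itself is collapsed in this view. For orientation, a pure-PyTorch RMSNorm typically normalizes by the root mean square of the last dimension and applies a learned scale; a sketch under that assumption (my own, not necessarily the repo's exact fallback):

import torch

class RMSNormSketch(torch.nn.Module):
    # y = weight * x / sqrt(mean(x^2, dim=-1) + eps)
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * inv_rms)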
13 changes: 12 additions & 1 deletion deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,2 +1,13 @@
-from .deeplink import DeepLinkApplyRotaryEmb, DeepLinkApplyRotaryEmbQKV_
+# Copyright (c) 2024, DeepLink.
+
+try:
+    from .deeplink import DeepLinkApplyRotaryEmb, DeepLinkApplyRotaryEmbQKV_
+except:
+    print(
+        "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation."
+    )
+    from .fallback import (
+        ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
+        ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
+    )
 from . import fallback
4 changes: 4 additions & 0 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -1,7 +1,11 @@
+# Copyright (c) 2024, DeepLink.
+
 import torch
 from einops import rearrange
 import deeplink_ext.cpp_extensions as ext
 
+assert hasattr(ext, "apply_rotary")
+
 
 class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function):
     @staticmethod
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/rotary/fallback/__init__.py
@@ -1 +1,3 @@
+# Copyright (c) 2024, DeepLink.
+
 from .fallback import ApplyRotaryEmb, ApplyRotaryEmbQKV_
2 changes: 2 additions & 0 deletions deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2024, DeepLink.
+
 import torch
 from einops import rearrange
 import deeplink_ext.cpp_extensions as ext
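Note that even this fallback still imports cpp_extensions, so only the heavy kernels are avoided, not the extension module itself. The core rotary transform rotates each channel pair by a position-dependent angle; a minimal sketch assuming the common non-interleaved (two-halves) layout, which may differ from the repo's exact fallback:

import torch

def apply_rotary_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., seqlen, 2 * rotary_dim), channel pairs split as two halves;
    # cos/sin: (seqlen, rotary_dim), broadcast over leading dims.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)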
3 changes: 3 additions & 0 deletions deeplink_ext/patch_internlm.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2024, DeepLink.
+
 def _patch_internlm():
     import importlib.util
     import os
@@ -67,6 +69,7 @@ def _patch_ops():
     _find_flash_attn()
     _patch_flash_attn()
     _patch_ops()
+    print("[deeplink_ext] patched diopi implementation of internlm")
 
 
 _patch_internlm()
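_patch_internlm() runs at module level, so merely importing deeplink_ext.patch_internlm applies the monkey-patches; the added print gives that side effect a visible confirmation. The patch-on-import idiom in miniature (illustrative, standalone):

import json

def _patch_json():
    # Swap a target function for a wrapped version, then announce it,
    # mirroring the "[deeplink_ext] patched ..." message above.
    original = json.dumps
    json.dumps = lambda obj, **kw: original(obj, sort_keys=True, **kw)
    print("[sketch] patched json.dumps to sort keys")

_patch_json()  # side effect of importing this module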
14 changes: 10 additions & 4 deletions deeplink_ext/patch_lightllm.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2024, DeepLink.
+
 def _patch_lightllm():
     import os
     import deeplink_ext.cpp_extensions as ext
@@ -25,7 +27,7 @@ def _patch_lightllm():
         patch_list_env.split(",") if use_custom_patch_list else DEFAULT_PATCH_LIST
     )
     if use_custom_patch_list:
-        print(f"Use custom lightllm patch list: {patch_list}")
+        print(f"[deeplink_ext] use custom lightllm patch list: {patch_list}")
 
     def try_patch(op: str):
         def patch_dest_index_copy_kv():
@@ -55,9 +57,13 @@ def patch_rotary_emb():
 
         try:
             locals()[f"patch_{op}"]()
-            print(f"Patched diopi implementation of {op}")
-        except:
-            print(f"Unknow op: {op}, supported ops: {DEFAULT_PATCH_LIST}")
+            print(f"[deeplink_ext] patched diopi implementation of {op}")
+        except KeyError:
+            print(
+                f"[deeplink_ext] unknow op: {op}, supported ops: {DEFAULT_PATCH_LIST}"
+            )
+        except AttributeError:
+            print(f"[deeplink_ext] op {op} is not implemented in diopi")
 
     for op in patch_list:
         try_patch(op)
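The reworked handler distinguishes the two failure modes precisely: locals()[f"patch_{op}"] raises KeyError when no patch_<op> function is defined (an unknown op), while a patch function touching a kernel absent from cpp_extensions raises AttributeError (op known, DIOPI implementation missing); the old bare except collapsed both into one misleading "Unknow op" message. A standalone sketch of that dispatch (illustrative, not the repo's code):

import types

ext = types.SimpleNamespace()  # stand-in extension module with no kernels

def try_patch(op: str):
    def patch_rms_norm():
        ext.rms_norm  # AttributeError: kernel missing from the extension

    try:
        locals()[f"patch_{op}"]()  # KeyError if patch_<op> is not defined
        print(f"[sketch] patched {op}")
    except KeyError:
        print(f"[sketch] unknown op: {op}")
    except AttributeError:
        print(f"[sketch] op {op} is not implemented")

try_patch("rms_norm")    # -> op rms_norm is not implemented
try_patch("flash_attn")  # -> unknown op: flash_attn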
2 changes: 2 additions & 0 deletions setup.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2024, DeepLink.
+
 from setuptools import find_packages, setup, Extension
 from torch.utils.cpp_extension import BuildExtension, include_paths, library_paths
 import glob
