diff --git a/.clang-format b/.clang-format index 9d13f969..ed68e746 100644 --- a/.clang-format +++ b/.clang-format @@ -1,275 +1,41 @@ --- Language: Cpp BasedOnStyle: Google -# AccessModifierOffset: -1 -# AlignAfterOpenBracket: Align -# AlignArrayOfStructures: None -# AlignConsecutiveAssignments: -# Enabled: false -# AcrossEmptyLines: false -# AcrossComments: false -# AlignCompound: false -# PadOperators: true -# AlignConsecutiveBitFields: -# Enabled: false -# AcrossEmptyLines: false -# AcrossComments: false -# AlignCompound: false -# PadOperators: false -# AlignConsecutiveDeclarations: -# Enabled: false -# AcrossEmptyLines: false -# AcrossComments: false -# AlignCompound: false -# PadOperators: false -# AlignConsecutiveMacros: -# Enabled: false -# AcrossEmptyLines: false -# AcrossComments: false -# AlignCompound: false -# PadOperators: false -# AlignEscapedNewlines: Left -# AlignOperands: Align -# AlignTrailingComments: -# Kind: Always -# OverEmptyLines: 0 -# AllowAllArgumentsOnNextLine: true -# AllowAllParametersOfDeclarationOnNextLine: true -# AllowShortBlocksOnASingleLine: Never -# AllowShortCaseLabelsOnASingleLine: false -# AllowShortEnumsOnASingleLine: true -# AllowShortFunctionsOnASingleLine: All -# AllowShortIfStatementsOnASingleLine: WithoutElse -# AllowShortLambdasOnASingleLine: All -# AllowShortLoopsOnASingleLine: true -# AlwaysBreakAfterDefinitionReturnType: None -# AlwaysBreakAfterReturnType: None -# AlwaysBreakBeforeMultilineStrings: true -# AlwaysBreakTemplateDeclarations: Yes -# AttributeMacros: -# - __capability -# BinPackArguments: true -# BinPackParameters: true -# BitFieldColonSpacing: Both -# BraceWrapping: -# AfterCaseLabel: false -# AfterClass: false -# AfterControlStatement: Never -# AfterEnum: false -# AfterExternBlock: false -# AfterFunction: false -# AfterNamespace: false -# AfterObjCDeclaration: false -# AfterStruct: false -# AfterUnion: false -# BeforeCatch: false -# BeforeElse: false -# BeforeLambdaBody: false -# BeforeWhile: false -# IndentBraces: false -# SplitEmptyFunction: true -# SplitEmptyRecord: true -# SplitEmptyNamespace: true -# BreakAfterAttributes: Never -# BreakAfterJavaFieldAnnotations: false -# BreakArrays: true -# BreakBeforeBinaryOperators: None -# BreakBeforeConceptDeclarations: Always -# BreakBeforeBraces: Attach -# BreakBeforeInlineASMColon: OnlyMultiline -# BreakBeforeTernaryOperators: true -# BreakConstructorInitializers: BeforeColon -# BreakInheritanceList: BeforeColon -# BreakStringLiterals: true -# ColumnLimit: 80 -# CommentPragmas: '^ IWYU pragma:' -# CompactNamespaces: false -# ConstructorInitializerIndentWidth: 4 -# ContinuationIndentWidth: 4 -# Cpp11BracedListStyle: true -# DerivePointerAlignment: true -# DisableFormat: false -# EmptyLineAfterAccessModifier: Never -# EmptyLineBeforeAccessModifier: LogicalBlock -# ExperimentalAutoDetectBinPacking: false -# FixNamespaceComments: true -# ForEachMacros: -# - foreach -# - Q_FOREACH -# - BOOST_FOREACH -# IfMacros: -# - KJ_IF_MAYBE -# IncludeBlocks: Regroup +BreakAfterAttributes: Leave +CommentPragmas: '^ (IWYU pragma:|NOLINT(BEGIN|END|NEXTLINE)?(\(.+\))?:? 
)' +DerivePointerAlignment: false +InsertNewlineAtEOF: true IncludeCategories: - - Regex: '^("|<)csrc_dipu/.*' - Priority: 9 - SortPriority: 0 - CaseSensitive: false - - Regex: '^("|<)diopi/.*' - Priority: 8 - SortPriority: 0 - CaseSensitive: false - - Regex: '^("|<)(c10|aten|torch).*' - Priority: 4 - SortPriority: 0 - CaseSensitive: false - - Regex: '^("|<)(pybind11|Python\.h).*' - Priority: 5 - SortPriority: 0 - CaseSensitive: false - - Regex: '^<((ext/.*)|pthread)\.h.*' - Priority: 2 - SortPriority: 1 - CaseSensitive: false - - Regex: '^("|<)(cuda|su|cn|(..?ccl)|(.*_runtime)).*\.h.*' - Priority: 3 - SortPriority: 0 - CaseSensitive: false - - Regex: '^<.*' - Priority: 2 - SortPriority: 0 - CaseSensitive: false - - Regex: '.*' - Priority: 10 - SortPriority: 0 - CaseSensitive: false -# IncludeIsMainRegex: '([-_](test|unittest))?$' -# IncludeIsMainSourceRegex: '' -# IndentAccessModifiers: false -# IndentCaseBlocks: false -# IndentCaseLabels: true -# IndentExternBlock: AfterExternBlock -# IndentGotoLabels: true -# IndentPPDirectives: None -# IndentRequiresClause: true -# IndentWidth: 2 -# IndentWrappedFunctionNames: false -# InsertBraces: false -# InsertNewlineAtEOF: false -# InsertTrailingCommas: None -# IntegerLiteralSeparator: -# Binary: 0 -# BinaryMinDigits: 0 -# Decimal: 0 -# DecimalMinDigits: 0 -# Hex: 0 -# HexMinDigits: 0 -# JavaScriptQuotes: Leave -# JavaScriptWrapImports: true -# KeepEmptyLinesAtTheStartOfBlocks: false -# LambdaBodyIndentation: Signature -# LineEnding: DeriveLF -# MacroBlockBegin: '' -# MacroBlockEnd: '' -# MaxEmptyLinesToKeep: 1 -# NamespaceIndentation: None -# ObjCBinPackProtocolList: Never -# ObjCBlockIndentWidth: 2 -# ObjCBreakBeforeNestedBlockParam: true -# ObjCSpaceAfterProperty: false -# ObjCSpaceBeforeProtocolList: true -# PackConstructorInitializers: NextLine -# PenaltyBreakAssignment: 2 -# PenaltyBreakBeforeFirstCallParameter: 1 -# PenaltyBreakComment: 300 -# PenaltyBreakFirstLessLess: 120 -# PenaltyBreakOpenParenthesis: 0 -# PenaltyBreakString: 1000 -# PenaltyBreakTemplateDeclaration: 10 -# PenaltyExcessCharacter: 1000000 -# PenaltyIndentedWhitespace: 0 -# PenaltyReturnTypeOnItsOwnLine: 200 -# PointerAlignment: Left -# PPIndentWidth: -1 -# QualifierAlignment: Leave -# RawStringFormats: -# - Language: Cpp -# Delimiters: -# - cc -# - CC -# - cpp -# - Cpp -# - CPP -# - 'c++' -# - 'C++' -# CanonicalDelimiter: '' -# BasedOnStyle: google -# - Language: TextProto -# Delimiters: -# - pb -# - PB -# - proto -# - PROTO -# EnclosingFunctions: -# - EqualsProto -# - EquivToProto -# - PARSE_PARTIAL_TEXT_PROTO -# - PARSE_TEST_PROTO -# - PARSE_TEXT_PROTO -# - ParseTextOrDie -# - ParseTextProtoOrDie -# - ParseTestProto -# - ParsePartialTestProto -# CanonicalDelimiter: pb -# BasedOnStyle: google -# ReferenceAlignment: Pointer -# ReflowComments: true -# RemoveBracesLLVM: false -# RemoveSemicolon: false -# RequiresClausePosition: OwnLine -# RequiresExpressionIndentation: OuterScope -# SeparateDefinitionBlocks: Leave -# ShortNamespaceLines: 1 -# SortIncludes: CaseSensitive -# SortJavaStaticImport: Before -# SortUsingDeclarations: LexicographicNumeric -# SpaceAfterCStyleCast: false -# SpaceAfterLogicalNot: false -# SpaceAfterTemplateKeyword: true -# SpaceAroundPointerQualifiers: Default -# SpaceBeforeAssignmentOperators: true -# SpaceBeforeCaseColon: false -# SpaceBeforeCpp11BracedList: false -# SpaceBeforeCtorInitializerColon: true -# SpaceBeforeInheritanceColon: true -# SpaceBeforeParens: ControlStatements -# SpaceBeforeParensOptions: -# AfterControlStatements: true -# 
AfterForeachMacros: true -# AfterFunctionDefinitionName: false -# AfterFunctionDeclarationName: false -# AfterIfMacros: true -# AfterOverloadedOperator: false -# AfterRequiresInClause: false -# AfterRequiresInExpression: false -# BeforeNonEmptyParentheses: false -# SpaceBeforeRangeBasedForLoopColon: true -# SpaceBeforeSquareBrackets: false -# SpaceInEmptyBlock: false -# SpaceInEmptyParentheses: false -# SpacesBeforeTrailingComments: 2 -# SpacesInAngles: Never -# SpacesInConditionalStatement: false -# SpacesInContainerLiterals: true -# SpacesInCStyleCastParentheses: false -# SpacesInLineCommentPrefix: -# Minimum: 1 -# Maximum: -1 -# SpacesInParentheses: false -# SpacesInSquareBrackets: false -# Standard: Auto -# StatementAttributeLikeMacros: -# - Q_EMIT -# StatementMacros: -# - Q_UNUSED -# - QT_REQUIRE_VERSION -# TabWidth: 8 -# UseTab: Never -# WhitespaceSensitiveMacros: -# - BOOST_PP_STRINGIZE -# - CF_SWIFT_NAME -# - NS_SWIFT_NAME -# - PP_STRINGIZE -# - STRINGIZE -... - + - Regex: '^("|<)csrc_dipu/' + Priority: 90 + CaseSensitive: false + - Regex: '^("|<)diopi/' + Priority: 80 + CaseSensitive: false + - Regex: '^("|<)(c10|aten|torch)/' + Priority: 40 + CaseSensitive: false + - Regex: '^("|<)Python\.h' + Priority: 50 + CaseSensitive: false + - Regex: '^("|<)(frameobject|structmember)\.h' + Priority: 50 + SortPriority: 51 + CaseSensitive: false + - Regex: '^("|<)(pybind11)' + Priority: 50 + SortPriority: 52 + CaseSensitive: false + - Regex: '^<((ext/.*)|pthread)\.h' + Priority: 20 + SortPriority: 21 + CaseSensitive: false + - Regex: '^("|<)(cuda|su|cn|(..?ccl)|(.*_runtime)).*\.h' + Priority: 30 + CaseSensitive: false + - Regex: '^<.*' + Priority: 20 + CaseSensitive: false + - Regex: '.*' + Priority: 100 + CaseSensitive: false diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 00000000..c44af75f --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,32 @@ +name: format + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +jobs: + clang-format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: cpp-linter/cpp-linter-action@v2 + id: cpp-lint + with: + style: file + tidy-checks: '-*' # disable clang tidy at this stage + version: 17 + files-changed-only: false + - name: Fail test + if: steps.cpp-lint.outputs.checks-failed > 0 + run: echo "Some files failed the linting checks!" 
&& exit 1 + + python-black: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: # see: https://black.readthedocs.io/en/stable/getting_started.html + version: "~= 23.11.0" diff --git a/.gitignore b/.gitignore index 84497213..79e17dfb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,26 +1,49 @@ +### C++ ### +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class -# scripts -build.sh - # C extensions *.so -# autogened file -torch_dipu/csrc_dipu/aten/ops/*AutoGenedKernels* - -# Patch / Merge -*.orig -*.reg - # Distribution / packaging .Python -env/ build/ -build* develop-eggs/ dist/ downloads/ @@ -28,12 +51,15 @@ eggs/ .eggs/ lib/ lib64/ +parts/ sdist/ var/ wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -48,13 +74,17 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover +*.py,cover .hypothesis/ +.pytest_cache/ +cover/ # Translations *.mo @@ -63,6 +93,8 @@ coverage.xml # Django stuff: *.log local_settings.py +db.sqlite3 +db.sqlite3-journal # Flask stuff: instance/ @@ -71,25 +103,65 @@ instance/ # Scrapy stuff: .scrapy +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + # Jupyter Notebook .ipynb_checkpoints -# pyenv -.python-version +# IPython +profile_default/ +ipython_config.py -# celery beat schedule file +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff celerybeat-schedule +celerybeat.pid # SageMath parsed files *.sage.py -# dotenv .env - -# virtualenv .venv +env/ venv/ ENV/ +env.bak/ +venv.bak/ # Spyder project settings .spyderproject @@ -103,20 +175,68 @@ ENV/ # mypy .mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* -# vi -*.swp -*.swo +# .nfs files are created when an open file is removed but is still being accessed +.nfs* -# JetBrains CLion -.idea/ -cmake-build-*/ +### Vim ### +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] -# macOS system files -.DS_Store +# Session +Session.vim +Sessionx.vim -# VS Code workspace configuration -.vscode +# Temporary +.netrwhist +*~ +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ -#.core file -core.* diff --git a/README.md b/README.md index b303b9ae..838f7352 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,31 @@ # DeepLinkExt -基本思想仿照cpp extension,不过会先在python层判断该融合算子的diopi实现没有(具体判断方法为,在cpp层进行pybind时,如果没有diopi实现,则不进行pybind)。如果没有实现,则会在python层替换为torch的几个分离算子。 +The basic idea follows cpp extension, except that the Python layer first checks whether each fused operator has a diopi implementation (concretely, when doing the pybind at the cpp layer, an operator without a diopi implementation is simply not bound). If there is no implementation, the operator is replaced at the Python layer with several separate torch operators. -融合算子的diopi定义及实现放在DIOPI库里,本拓展库仅引用。 \ No newline at end of file +The diopi definitions and implementations of the fused operators live in the DIOPI library; this extension library only references them. + +Automatic patching of the fused operators used by InternLM and lightllm is supported, replacing them with their DIOPI implementations. + +## Install + +First install DIPU and make sure `import torch_dipu` works. Then run the following in this directory: + +```bash +pip install -e . +``` + +## Usage + +### InternLM + +```python +import deeplink_ext.patch_internlm +import internlm +``` + +### lightllm + +```python +import deeplink_ext.patch_lightllm +import lightllm +``` diff --git a/__init__.py b/__init__.py deleted file mode 100644 index 123ec8a6..00000000 --- a/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys -import os -cur_work_dir = os.path.dirname(__file__) -sys.path.append(cur_work_dir) -from .ext_apply.lightllm.mock_op import * \ No newline at end of file diff --git a/ext_op/diopi_helper.h b/csrc/diopi_helper.h similarity index 95% rename from ext_op/diopi_helper.h rename to csrc/diopi_helper.h index be769585..3400020b 100644 --- a/ext_op/diopi_helper.h +++ b/csrc/diopi_helper.h @@ -102,7 +102,8 @@ void callDiopi(DiopiFunc&& diopi_func, Args&&... 
args) { diopiError_t err_code = diopi_func(&ctx, castToDiopiType(std::forward(args))...); if (err_code != diopiSuccess) { - throw std::runtime_error("DIOPI call failed"); + throw std::runtime_error("DIOPI error, code: " + std::to_string(err_code) + + ", message: " + diopiGetLastErrorString()); } } diff --git a/ext_op/example_ext.cpp b/csrc/extensions.cpp similarity index 99% rename from ext_op/example_ext.cpp rename to csrc/extensions.cpp index 34713501..09c6e3e2 100644 --- a/ext_op/example_ext.cpp +++ b/csrc/extensions.cpp @@ -208,13 +208,13 @@ void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k, b_seq_len, max_input_len); } -void extApplyPenalty(at::Tensor& Logits, const at::Tensor& presence_penalty, +void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty, const at::Tensor& frequency_penalty, const at::Tensor& p_token_ids, const at::Tensor& p_token_counts, const at::Tensor& p_cumsum_seq_len, int p_max_len_in_batch) { - callDiopi(diopiApplyPenalty, Logits, presence_penalty, frequency_penalty, + callDiopi(diopiApplyPenalty, logits, presence_penalty, frequency_penalty, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch); } diff --git a/ext_op/pybind_type_cast.h b/csrc/pybind_type_cast.h similarity index 100% rename from ext_op/pybind_type_cast.h rename to csrc/pybind_type_cast.h diff --git a/deeplink_ext/__init__.py b/deeplink_ext/__init__.py new file mode 100644 index 00000000..23bc7ee9 --- /dev/null +++ b/deeplink_ext/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024, DeepLink. + + +def _init(): + # deeplink_ext is developed based on dipu + # so we explicitly import torch_dipu to guarantee that torch is patched by dipu + import torch_dipu + + +_init() diff --git a/deeplink_ext/internlm_ops/__init__.py b/deeplink_ext/internlm_ops/__init__.py new file mode 100644 index 00000000..aa73ef5b --- /dev/null +++ b/deeplink_ext/internlm_ops/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024, DeepLink. + +from . import mha, rms_norm, rotary diff --git a/deeplink_ext/internlm_ops/mha/__init__.py b/deeplink_ext/internlm_ops/mha/__init__.py new file mode 100644 index 00000000..17f90967 --- /dev/null +++ b/deeplink_ext/internlm_ops/mha/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, DeepLink. + +try: + from .mha import DeepLinkSelfAttention, DeepLinkCrossAttention +except Exception as e: + print( + "[deeplink_ext] mha is not implemented in diopi. Falling back to the slower implementation.\n", + end="", + ) + from .fallback import ( + SelfAttention as DeepLinkSelfAttention, + CrossAttention as DeepLinkCrossAttention, + ) +from . import fallback diff --git a/deeplink_ext/internlm_ops/mha/fallback/__init__.py b/deeplink_ext/internlm_ops/mha/fallback/__init__.py new file mode 100644 index 00000000..795ddbd9 --- /dev/null +++ b/deeplink_ext/internlm_ops/mha/fallback/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024, DeepLink. 
+ +from .fallback import SelfAttention, CrossAttention diff --git a/ext_apply/internlm/ext_mha/mha_fallback.py b/deeplink_ext/internlm_ops/mha/fallback/fallback.py similarity index 100% rename from ext_apply/internlm/ext_mha/mha_fallback.py rename to deeplink_ext/internlm_ops/mha/fallback/fallback.py diff --git a/ext_apply/internlm/ext_mha/mha.py b/deeplink_ext/internlm_ops/mha/mha.py similarity index 100% rename from ext_apply/internlm/ext_mha/mha.py rename to deeplink_ext/internlm_ops/mha/mha.py diff --git a/ext_apply/internlm/ext_mha/mha_func.py b/deeplink_ext/internlm_ops/mha/mha_func.py similarity index 86% rename from ext_apply/internlm/ext_mha/mha_func.py rename to deeplink_ext/internlm_ops/mha/mha_func.py index a6e7f6c9..3efecb5d 100644 --- a/ext_apply/internlm/ext_mha/mha_func.py +++ b/deeplink_ext/internlm_ops/mha/mha_func.py @@ -1,7 +1,9 @@ # Copyright (c) 2023, DeepLink. import torch -import dipu_ext.ext_ +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") class DeepLinkMultiHeadAttentionFunc(torch.autograd.Function): @@ -9,7 +11,7 @@ class DeepLinkMultiHeadAttentionFunc(torch.autograd.Function): def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = dipu_ext.ext_.mha_fwd( + out, softmax_lse, rng, S_dmask = ext.mha_fwd( q, k, v, @@ -29,7 +31,7 @@ def backward(ctx, dout): q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors rng = torch.Generator(device=q.device) rng.set_state(rng_state) - dq, dk, dv = dipu_ext.ext_.mha_bwd( + dq, dk, dv = ext.mha_bwd( dout, q, k, diff --git a/ext_apply/internlm/ext_mha/mha_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py similarity index 88% rename from ext_apply/internlm/ext_mha/mha_kvpacked_func.py rename to deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py index a2fa34a8..33e248f1 100644 --- a/ext_apply/internlm/ext_mha/mha_kvpacked_func.py +++ b/deeplink_ext/internlm_ops/mha/mha_kvpacked_func.py @@ -1,7 +1,9 @@ # Copyright (c) 2023, DeepLink. import torch -import dipu_ext.ext_ +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") class DeepLinkMultiHeadAttentionKVPackedFunc(torch.autograd.Function): @@ -9,7 +11,7 @@ class DeepLinkMultiHeadAttentionKVPackedFunc(torch.autograd.Function): def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = dipu_ext.ext_.mha_fwd( + out, softmax_lse, rng, S_dmask = ext.mha_fwd( q, kv[:, :, 0], kv[:, :, 1], @@ -31,7 +33,7 @@ def backward(ctx, dout): dkv = torch.empty_like(kv) rng = torch.Generator(device=q.device) rng.set_state(rng_state) - dipu_ext.ext_.mha_bwd( + ext.mha_bwd( dout, q, kv[:, :, 0], diff --git a/ext_apply/internlm/ext_mha/mha_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py similarity index 88% rename from ext_apply/internlm/ext_mha/mha_qkvpacked_func.py rename to deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py index 3e4052e7..61527adb 100644 --- a/ext_apply/internlm/ext_mha/mha_qkvpacked_func.py +++ b/deeplink_ext/internlm_ops/mha/mha_qkvpacked_func.py @@ -1,7 +1,9 @@ # Copyright (c) 2023, DeepLink. 
import torch -import dipu_ext.ext_ +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") class DeepLinkMultiHeadAttentionQKVPackedFunc(torch.autograd.Function): @@ -9,7 +11,7 @@ class DeepLinkMultiHeadAttentionQKVPackedFunc(torch.autograd.Function): def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = dipu_ext.ext_.mha_fwd( + out, softmax_lse, rng, S_dmask = ext.mha_fwd( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], @@ -30,7 +32,7 @@ def backward(ctx, dout): dqkv = torch.empty_like(qkv) rng = torch.Generator(device=qkv.device) rng.set_state(rng_state) - dipu_ext.ext_.mha_bwd( + ext.mha_bwd( dout, qkv[:, :, 0], qkv[:, :, 1], diff --git a/ext_apply/internlm/ext_mha/mha_varlen_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_func.py similarity index 90% rename from ext_apply/internlm/ext_mha/mha_varlen_func.py rename to deeplink_ext/internlm_ops/mha/mha_varlen_func.py index 046c6077..018f2e42 100644 --- a/ext_apply/internlm/ext_mha/mha_varlen_func.py +++ b/deeplink_ext/internlm_ops/mha/mha_varlen_func.py @@ -1,7 +1,9 @@ # Copyright (c) 2023, DeepLink. import torch -import dipu_ext.ext_ +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") class DeepLinkMultiHeadAttentionVarLenFunc(torch.autograd.Function): @@ -22,7 +24,7 @@ def forward( ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = dipu_ext.ext_.mha_varlen_fwd( + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( q, k, v, @@ -59,7 +61,7 @@ def backward(ctx, dout): ) = ctx.saved_tensors rng = torch.Generator(device=q.device) rng.set_state(rng_state) - dq, dk, dv = dipu_ext.ext_.mha_varlen_bwd( + dq, dk, dv = ext.mha_varlen_bwd( dout, q, k, diff --git a/ext_apply/internlm/ext_mha/mha_varlen_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py similarity index 91% rename from ext_apply/internlm/ext_mha/mha_varlen_kvpacked_func.py rename to deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py index 544bdcb4..4c0c3a70 100644 --- a/ext_apply/internlm/ext_mha/mha_varlen_kvpacked_func.py +++ b/deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py @@ -1,7 +1,9 @@ # Copyright (c) 2023, DeepLink. import torch -import dipu_ext.ext_ +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") class DeepLinkMultiHeadAttentionVarLenKVPackedFunc(torch.autograd.Function): @@ -21,7 +23,7 @@ def forward( ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = dipu_ext.ext_.mha_varlen_fwd( + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( q, kv[:, :, 0], kv[:, :, 1], @@ -58,7 +60,7 @@ def backward(ctx, dout): dkv = torch.empty_like(kv) rng = torch.Generator(device=q.device) rng.set_state(rng_state) - dipu_ext.ext_.mha_varlen_bwd( + ext.mha_varlen_bwd( dout, q, kv[:, :, 0], diff --git a/ext_apply/internlm/ext_mha/mha_varlen_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py similarity index 89% rename from ext_apply/internlm/ext_mha/mha_varlen_qkvpacked_func.py rename to deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py index 56562589..d42e0f7c 100644 --- a/ext_apply/internlm/ext_mha/mha_varlen_qkvpacked_func.py +++ b/deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py @@ -1,7 +1,9 @@ # Copyright (c) 2023, DeepLink. 
import torch -import dipu_ext.ext_ +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "mha_fwd") and hasattr(ext, "mha_bwd") class DeepLinkMultiHeadAttentionVarLenQKVPackedFunc(torch.autograd.Function): @@ -18,7 +20,7 @@ def forward( ): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) - out, softmax_lse, rng, S_dmask = dipu_ext.ext_.mha_varlen_fwd( + out, softmax_lse, rng, S_dmask = ext.mha_varlen_fwd( qkv[:, 0], qkv[:, 1], qkv[:, 2], @@ -44,7 +46,7 @@ def backward(ctx, dout): dqkv = torch.empty_like(qkv) rng = torch.Generator(device=qkv.device) rng.set_state(rng_state) - dipu_ext.ext_.mha_varlen_bwd( + ext.mha_varlen_bwd( dout, qkv[:, 0], qkv[:, 1], diff --git a/deeplink_ext/internlm_ops/rms_norm/__init__.py b/deeplink_ext/internlm_ops/rms_norm/__init__.py new file mode 100644 index 00000000..6ab1396c --- /dev/null +++ b/deeplink_ext/internlm_ops/rms_norm/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2024, DeepLink. + +try: + from .deeplink import DeepLinkRMSNorm, DeepLinkRMSNormWithNormalizedShape +except: + print( + "[deeplink_ext] rms_norm is not implemented in diopi. Falling back to the slower implementation.\n", + end="", + ) + from .fallback import RMSNorm as DeepLinkRMSNorm +from . import fallback diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py new file mode 100644 index 00000000..b0f58d7b --- /dev/null +++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py @@ -0,0 +1,92 @@ +# Copyright (c) 2024, DeepLink. + +import torch +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "rms_norm") + + +# 定义自定义的 autograd 函数 +class _DeepLinkRMSNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden_states, weight, bias, eps): + output, inv_rms = ext.rms_norm(hidden_states, None, weight, bias, eps) + + ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) + return output + + @staticmethod + def backward(ctx, grad_output): + hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors + eps = eps_tensor.item() + grad_input, grad_weight, grad_bias = ext.rms_norm_backward( + hidden_states, grad_output, inv_rms, None, weight, bias, eps + ) + return grad_input, grad_weight, grad_bias, None + + +class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): + output, inv_rms = ext.rms_norm( + hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps + ) + output = output.half() + inv_rms = inv_rms.half() + ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) + hidden_states = hidden_states.half() + weight = weight.half() + bias = bias.half() + ctx.intermediate_results = normalized_shape + return output + + @staticmethod + def backward(ctx, grad_output): + hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors + eps = eps_tensor.item() + normalized_shape = ctx.intermediate_results + hidden_states = hidden_states.float() + inv_rms = inv_rms.float() + weight = weight.float() + bias = bias.float() + grad_output = grad_output.float() + grad_input, grad_weight, grad_bias = ext.rms_norm_backward( + hidden_states, grad_output, inv_rms, normalized_shape, weight, bias, eps + ) + grad_output = grad_output.half() + hidden_states = hidden_states.half() + inv_rms = inv_rms.half() + weight = weight.half() + bias = bias.half() + return grad_input, grad_weight, grad_bias, None, None + + +# 定义一个 nn.Module 
包裹这个自定义函数 +class DeepLinkRMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.bias = torch.zeros(hidden_size).cuda() + self.variance_epsilon = eps + + def forward(self, hidden_states): + return _DeepLinkRMSNormFunction.apply( + hidden_states, self.weight, self.bias, self.variance_epsilon + ) + + +class DeepLinkRMSNormWithNormalizedShape(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.bias = torch.zeros(hidden_size).cuda() + self.variance_epsilon = eps + + def forward(self, hidden_states): + return _DeepLinkRMSNormFunctionWithNormalizedShape.apply( + hidden_states, + self.weight, + self.bias, + self.variance_epsilon, + self.weight.size(), + ) diff --git a/deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py b/deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py new file mode 100644 index 00000000..3ad8f243 --- /dev/null +++ b/deeplink_ext/internlm_ops/rms_norm/fallback/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024, DeepLink. + +from .fallback import RMSNorm diff --git a/deeplink_ext/internlm_ops/rms_norm/fallback/fallback.py b/deeplink_ext/internlm_ops/rms_norm/fallback/fallback.py new file mode 100644 index 00000000..cb11a2c4 --- /dev/null +++ b/deeplink_ext/internlm_ops/rms_norm/fallback/fallback.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, DeepLink. + +import torch + + +# RMSNorm fallback from InternLM +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + InternLMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states diff --git a/deeplink_ext/internlm_ops/rotary/__init__.py b/deeplink_ext/internlm_ops/rotary/__init__.py new file mode 100644 index 00000000..2ebce250 --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024, DeepLink. + +try: + from .deeplink import DeepLinkApplyRotaryEmb, DeepLinkApplyRotaryEmbQKV_ +except: + print( + "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n", + end="", + ) + from .fallback import ( + ApplyRotaryEmb as DeepLinkApplyRotaryEmb, + ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_, + ) +from . import fallback diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary/deeplink.py new file mode 100644 index 00000000..2f41014c --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary/deeplink.py @@ -0,0 +1,170 @@ +# Copyright (c) 2024, DeepLink. 
+ +import torch +from einops import rearrange +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "apply_rotary") + + +class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + assert ( + sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) + ) + q_ro = qkv[:, :, 0, :, :rotary_dim] + ext.apply_rotary( + q_ro, + q_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + False, + False, + ) + k_ro = qkv[:, :, 1, :, :rotary_dim] + ext.apply_rotary( + k_ro, + k_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + False, + False, + ) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + cos, sin, cos_k, sin_k = ctx.saved_tensors + interleaved = ctx.interleaved + _, seqlen, _, _, headdim = dqkv.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + dq_ro = dqkv[:, :, 0, :, :rotary_dim] + ext.apply_rotary( + dq_ro, + dq_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + True, + False, + ) + dk_ro = dqkv[:, :, 1, :, :rotary_dim] + ext.apply_rotary( + dk_ro, + dk_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + True, + False, + ) + return dqkv, None, None, None, None, None + + +class DeepLinkApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward(ctx, x, cos, sin, interleaved=False, inplace=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + batch, seqlen, nheads, headdim = x.shape + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) + x_ro = x[..., :rotary_dim] + x1, x2 = ( + x_ro.chunk(2, dim=-1) + if not interleaved + else (x_ro[..., ::2], x_ro[..., 1::2]) + ) + out = torch.empty_like(x) if not inplace else x + out_ro = out[..., :rotary_dim] + + if inplace: + ext.apply_rotary( + out_ro, + x_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + False, + False, + ) + else: + ext.apply_rotary( + out_ro, + x_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + False, + False, + ) + + if not inplace and rotary_dim < headdim: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + ctx.inplace = inplace + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + cos, sin = ctx.saved_tensors + _, seqlen, _, headdim = do.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + inplace = ctx.inplace + do_ro = do[..., :rotary_dim] + do1, do2 = ( + do_ro.chunk(2, dim=-1) + if not ctx.interleaved + else (do_ro[..., ::2], do_ro[..., 1::2]) + ) + dx = torch.empty_like(do) if not inplace else do + if inplace: + dx1, dx2 = do1, do2 + else: + dx_ro = dx[..., :rotary_dim] + dx1, dx2 = ( + dx_ro.chunk(2, dim=-1) + if not ctx.interleaved + else (dx_ro[..., ::2], dx_ro[..., 1::2]) + ) + if inplace: + ext.apply_rotary( + do_ro, + do_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + True, + False, + ) + else: + ext.apply_rotary( + dx_ro, + do_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + True, + False, + ) + + if not inplace and rotary_dim < headdim: + dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) + return dx, None, None, None, None diff --git a/deeplink_ext/internlm_ops/rotary/fallback/__init__.py b/deeplink_ext/internlm_ops/rotary/fallback/__init__.py new file mode 100644 index 00000000..722045da --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary/fallback/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024, DeepLink. + +from .fallback import ApplyRotaryEmb, ApplyRotaryEmbQKV_ diff --git a/deeplink_ext/internlm_ops/rotary/fallback/fallback.py b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py new file mode 100644 index 00000000..bc07d715 --- /dev/null +++ b/deeplink_ext/internlm_ops/rotary/fallback/fallback.py @@ -0,0 +1,191 @@ +# Copyright (c) 2024, DeepLink. 
+ +import torch +from einops import rearrange +import deeplink_ext.cpp_extensions as ext + + +# Rotary_emb +# torch 绕过实现函数 +def apply_rotary(x1, x2, cos, sin, conj): + data_dtype = x1.dtype + x1 = x1.to(torch.float32) + x2 = x2.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + if not conj: + out1 = x1 * cos - x2 * sin + out2 = x1 * sin + x2 * cos + else: + out1 = x1 * cos + x2 * sin + out2 = -x1 * sin + x2 * cos + out1 = out1.to(data_dtype) + out2 = out2.to(data_dtype) + return out1, out2 + + +class ApplyRotaryEmbQKV_(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): + """ + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of q and k. + """ + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + assert ( + sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) + ) + q_ro = qkv[:, :, 0, :, :rotary_dim] + q1, q2 = ( + q_ro.chunk(2, dim=-1) + if not interleaved + else (q_ro[..., ::2], q_ro[..., 1::2]) + ) + q1, q2 = apply_rotary( + q1, + q2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + False, + ) + qkv[:, :, 0, :, :rotary_dim] = torch.cat((q1, q2), dim=-1) + k_ro = qkv[:, :, 1, :, :rotary_dim] + k1, k2 = ( + k_ro.chunk(2, dim=-1) + if not interleaved + else (k_ro[..., ::2], k_ro[..., 1::2]) + ) + k1, k2 = apply_rotary( + k1, + k2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + False, + ) + qkv[:, :, 1, :, :rotary_dim] = torch.cat((k1, k2), dim=-1) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + cos, sin, cos_k, sin_k = ctx.saved_tensors + _, seqlen, _, _, headdim = dqkv.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + dq_ro = dqkv[:, :, 0, :, :rotary_dim] + dq1, dq2 = ( + dq_ro.chunk(2, dim=-1) + if not ctx.interleaved + else (dq_ro[..., ::2], dq_ro[..., 1::2]) + ) + dq1, dq2 = apply_rotary( + dq1, + dq2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + True, + ) + dqkv[:, :, 0, :, :rotary_dim] = torch.cat((dq1, dq2), dim=-1) + dk_ro = dqkv[:, :, 1, :, :rotary_dim] + dk1, dk2 = ( + dk_ro.chunk(2, dim=-1) + if not ctx.interleaved + else (dk_ro[..., ::2], dk_ro[..., 1::2]) + ) + dk1, dk2 = apply_rotary( + dk1, + dk2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + True, + ) + dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1) + return dqkv, None, None, None, None, None + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward(ctx, x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + batch, seqlen, nheads, headdim = x.shape + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) + x_ro = x[..., :rotary_dim] + x1, x2 = ( + x_ro.chunk(2, dim=-1) + if not interleaved + else (x_ro[..., ::2], x_ro[..., 1::2]) + ) + out = torch.empty_like(x) + out_ro = out[..., :rotary_dim] + + ext.apply_rotary( + out_ro, + x_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + False, + False, + ) + + if rotary_dim < headdim: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + return out + + @staticmethod + def backward(ctx, do): + cos, sin = ctx.saved_tensors + _, seqlen, _, headdim = do.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + do_ro = do[..., :rotary_dim] + do1, do2 = ( + do_ro.chunk(2, dim=-1) + if not ctx.interleaved + else (do_ro[..., ::2], do_ro[..., 1::2]) + ) + dx = torch.empty_like(do) + + dx_ro = dx[..., :rotary_dim] + dx1, dx2 = ( + dx_ro.chunk(2, dim=-1) + if not ctx.interleaved + else (dx_ro[..., ::2], dx_ro[..., 1::2]) + ) + ext.apply_rotary( + dx_ro, + do_ro, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + True, + False, + ) + + if rotary_dim < headdim: + dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) + return dx, None, None, None, None diff --git a/deeplink_ext/patch_internlm.py b/deeplink_ext/patch_internlm.py new file mode 100644 index 00000000..85808fd7 --- /dev/null +++ b/deeplink_ext/patch_internlm.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024, DeepLink. + + +def _patch_internlm(): + import importlib.util + import os + import sys + import unittest.mock as mock + import deeplink_ext.internlm_ops as ext + + def _find_or_mock_module(module_name): + module_spec = importlib.util.find_spec(module_name) + if module_spec is None: + sys.modules[module_name] = mock.Mock() + + def _find_flash_attn(): + flash_attn_spec = importlib.util.find_spec("flash_attn") + if flash_attn_spec is None: + internlm_spec = importlib.util.find_spec("internlm") + assert internlm_spec is not None + assert internlm_spec.submodule_search_locations is not None + sys.path.append( + os.path.abspath( + os.path.join( + internlm_spec.submodule_search_locations[0], + "../third_party/flash-attention", + ) + ) + ) + + def _patch_flash_attn(): + import flash_attn.losses.cross_entropy # type: ignore + import torch.nn + + def CrossEntropyLossProxy(reduction, **_): + return torch.nn.CrossEntropyLoss(reduction=reduction) + + flash_attn.losses.cross_entropy.CrossEntropyLoss = CrossEntropyLossProxy + + import flash_attn.modules.mha # type: ignore + + flash_attn.modules.mha.SelfAttention = ext.mha.DeepLinkSelfAttention + flash_attn.modules.mha.FlashSelfAttention = ext.mha.DeepLinkSelfAttention + flash_attn.modules.mha.CrossAttention = ext.mha.DeepLinkCrossAttention + flash_attn.modules.mha.FlashCrossAttention = ext.mha.DeepLinkCrossAttention + + def _patch_ops(): + import internlm.model.embedding # type: ignore + + # TODO(lljbash,gongqiwei): implement a module aligned with rotary_emb + def NotImplementedRotaryEnb(*args, **kwargs): + raise NotImplementedError( + "the patch for apply_rotary_emb_qkv_ (requires rotary_emb) has not been implemented in deeplink_ext yet" + ) + + internlm.model.embedding.apply_rotary_emb_qkv_ = NotImplementedRotaryEnb + internlm.model.embedding.legacy_apply_rotary_embed = ( + 
ext.rotary.DeepLinkApplyRotaryEmb.apply + ) + internlm.model.embedding.legacy_apply_rotary_embed_qkv = ( + ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply + ) + + import internlm.model.norm # type: ignore + + internlm.model.norm.RMSNormTorch = ( + ext.rms_norm.DeepLinkRMSNormWithNormalizedShape + ) + + _find_or_mock_module("rotary_emb") + _find_or_mock_module("fused_dense_lib") + _find_or_mock_module("xentropy_cuda_lib") + _find_or_mock_module("flash_attn_cuda") + _find_flash_attn() + _patch_flash_attn() + _patch_ops() + print("[deeplink_ext] patched diopi implementation of internlm\n", end="") + + +_patch_internlm() diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py new file mode 100644 index 00000000..d371870a --- /dev/null +++ b/deeplink_ext/patch_lightllm.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024, DeepLink. + + +def _patch_lightllm(): + import os + import deeplink_ext.cpp_extensions as ext + import lightllm.common.basemodel.triton_kernel.destindex_copy_kv as destindex_copy_kv_pack # type: ignore + import lightllm.common.basemodel.triton_kernel.apply_penalty as apply_penalty_pack # type: ignore + import lightllm.models.llama.triton_kernel.context_flashattention_nopad as context_attention_pack # type: ignore + import lightllm.models.llama.triton_kernel.token_attention_nopad_att1 as token_attention_pack # type: ignore + import lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev as token_attention_softmax_reducev_pack # type: ignore + import lightllm.models.llama.triton_kernel.rmsnorm as rms_norm_pack # type: ignore + import lightllm.models.llama.triton_kernel.rotary_emb as rotary_emb_pack # type: ignore + + DEFAULT_PATCH_LIST = [ + "dest_index_copy_kv", + "apply_penalty", + "context_attention_inference", + "token_attention_inference", + "token_softmax_reducev_inference", + "rms_norm_lightllm", + "rotary_emb", + ] + PATCH_LIST_ENV_NAME = "DEEPLINKEXT_LIGHTLLM_PATCH_LIST" + patch_list_env = os.environ.get(PATCH_LIST_ENV_NAME) + use_custom_patch_list = patch_list_env is not None + patch_list = ( + patch_list_env.split(",") if use_custom_patch_list else DEFAULT_PATCH_LIST + ) + if use_custom_patch_list: + print(f"[deeplink_ext] use custom lightllm patch list: {patch_list}\n", end="") + + def try_patch(op: str): + def patch_dest_index_copy_kv(): + destindex_copy_kv_pack.destindex_copy_kv = ext.dest_index_copy_kv + + def patch_apply_penalty(): + apply_penalty_pack.apply_penalty = ext.apply_penalty + + def patch_context_attention_inference(): + context_attention_pack.context_attention_fwd = ( + ext.context_attention_inference + ) + + def patch_token_attention_inference(): + token_attention_pack.token_att_fwd = ext.token_attention_inference + + def patch_token_softmax_reducev_inference(): + token_attention_softmax_reducev_pack.token_softmax_reducev_fwd = ( + ext.token_softmax_reducev_inference + ) + + def patch_rms_norm_lightllm(): + rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm + + def patch_rotary_emb(): + rotary_emb_pack.rotary_emb_fwd = ext.rotary_emb + + try: + locals()[f"patch_{op}"]() + print(f"[deeplink_ext] patched diopi implementation of {op}\n", end="") + except KeyError: + print( + f"[deeplink_ext] unknown op: {op}, supported ops: {DEFAULT_PATCH_LIST}\n", + end="", + ) + except AttributeError: + print(f"[deeplink_ext] op {op} is not implemented in diopi\n", end="") + + for op in patch_list: + try_patch(op) + + +_patch_lightllm() diff --git a/ext_apply/internlm/RMSNorm.py b/ext_apply/internlm/RMSNorm.py deleted file mode 100644 index 
d2e3453c..00000000 --- a/ext_apply/internlm/RMSNorm.py +++ /dev/null @@ -1,125 +0,0 @@ -import torch -from torch import nn -import torch_dipu -import dipu_ext.ext_ as deeplink_ext -import copy - - -class InternLMRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - InternLMRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -# 定义自定义的autograd函数 -class _DeepLinkRMSNormFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_states, weight, bias, eps): - output, inv_rms = deeplink_ext.rms_norm( - hidden_states, - None, - weight, - bias, - eps - ) - - ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) - return output - - @staticmethod - def backward(ctx, grad_output): - hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors - eps = eps_tensor.item() - grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward( - hidden_states, - grad_output, - inv_rms, - None, - weight, - bias, - eps - ) - return grad_input, grad_weight, grad_bias, None - -class _DeepLinkRMSNormFunction_WithNormalizedShape(torch.autograd.Function): - @staticmethod - def forward(ctx, hidden_states, weight, bias, eps, normalized_shape): - output, inv_rms = deeplink_ext.rms_norm( - hidden_states.float(), - normalized_shape, - weight.float(), - bias.float(), - eps - ) - output = output.half() - inv_rms = inv_rms.half() - ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps)) - hidden_states = hidden_states.half() - weight = weight.half() - bias = bias.half() - ctx.intermediate_results = normalized_shape - return output - - @staticmethod - def backward(ctx, grad_output): - hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors - eps = eps_tensor.item() - normalized_shape = ctx.intermediate_results - hidden_states = hidden_states.float() - inv_rms = inv_rms.float() - weight = weight.float() - bias = bias.float() - grad_output = grad_output.float() - grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward( - hidden_states, - grad_output, - inv_rms, - normalized_shape, - weight, - bias, - eps - ) - grad_output = grad_output.half() - hidden_states = hidden_states.half() - inv_rms = inv_rms.half() - weight = weight.half() - bias = bias.half() - return grad_input, grad_weight, grad_bias, None, None - - -# 定义一个nn.Module包裹这个自定义函数 -class DeepLinkRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = torch.zeros(hidden_size).cuda() - self.variance_epsilon = eps - - def forward(self, hidden_states): - return _DeepLinkRMSNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon) - -class DeepLinkRMSNorm_WithNormalizedShape(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = torch.zeros(hidden_size).cuda() - self.variance_epsilon = eps - - def forward(self, hidden_states): - return 
_DeepLinkRMSNormFunction_WithNormalizedShape.apply(hidden_states, self.weight, self.bias, self.variance_epsilon, self.weight.size()) diff --git a/ext_apply/internlm/ext_apply_rotary.py b/ext_apply/internlm/ext_apply_rotary.py deleted file mode 100644 index ae6b18da..00000000 --- a/ext_apply/internlm/ext_apply_rotary.py +++ /dev/null @@ -1,344 +0,0 @@ -import torch -from einops import rearrange - - -# Rotary_emb -# torch绕过实现函数 -def torch_apply_rotary(x1, x2, cos, sin, conj): - data_dtype = x1.dtype - x1 = x1.to(torch.float32) - x2 = x2.to(torch.float32) - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - if not conj: - out1 = x1 * cos - x2 * sin - out2 = x1 * sin + x2 * cos - else: - out1 = x1 * cos + x2 * sin - out2 = -x1 * sin + x2 * cos - out1 = out1.to(data_dtype) - out2 = out2.to(data_dtype) - return out1, out2 - - -class TorchApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - """ - qkv: (batch_size, seqlen, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of q and k. - """ - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert ( - sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - ) - q_ro = qkv[:, :, 0, :, :rotary_dim] - q1, q2 = ( - q_ro.chunk(2, dim=-1) - if not interleaved - else (q_ro[..., ::2], q_ro[..., 1::2]) - ) - q1, q2 = torch_apply_rotary( - q1, - q2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - qkv[:, :, 0, :, :rotary_dim] = torch.cat((q1, q2), dim=-1) - k_ro = qkv[:, :, 1, :, :rotary_dim] - k1, k2 = ( - k_ro.chunk(2, dim=-1) - if not interleaved - else (k_ro[..., ::2], k_ro[..., 1::2]) - ) - k1, k2 = torch_apply_rotary( - k1, - k2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - ) - qkv[:, :, 1, :, :rotary_dim] = torch.cat((k1, k2), dim=-1) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - _, seqlen, _, _, headdim = dqkv.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, :, 0, :, :rotary_dim] - dq1, dq2 = ( - dq_ro.chunk(2, dim=-1) - if not ctx.interleaved - else (dq_ro[..., ::2], dq_ro[..., 1::2]) - ) - dq1, dq2 = torch_apply_rotary( - dq1, - dq2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ) - dqkv[:, :, 0, :, :rotary_dim] = torch.cat((dq1, dq2), dim=-1) - dk_ro = dqkv[:, :, 1, :, :rotary_dim] - dk1, dk2 = ( - dk_ro.chunk(2, dim=-1) - if not ctx.interleaved - else (dk_ro[..., ::2], dk_ro[..., 1::2]) - ) - dk1, dk2 = torch_apply_rotary( - dk1, - dk2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ) - dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1) - return dqkv, None, None, None, None, None - - -class TorchApplyRotaryEmb(torch.autograd.Function): - - @staticmethod - def 
forward(ctx, x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - batch, seqlen, nheads, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) - out = torch.empty_like(x) - out_ro = out[..., :rotary_dim] - - dipu_ext.ext_.apply_rotary( - out_ro, - x_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - False - ) - - if rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - do_ro = do[..., :rotary_dim] - do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved - else (do_ro[..., ::2], do_ro[..., 1::2])) - dx = torch.empty_like(do) - - dx_ro = dx[..., :rotary_dim] - dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved - else (dx_ro[..., ::2], dx_ro[..., 1::2])) - dipu_ext.ext_.apply_rotary( - dx_ro, - do_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - False - ) - - if rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None - - -try: - import dipu_ext.ext_ - - print("using ext apply_rotary") - - class DeepLinkApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert ( - sin.shape - == cos_k.shape - == sin_k.shape - == (rotary_seqlen, rotary_dim // 2) - ) - q_ro = qkv[:, :, 0, :, :rotary_dim] - dipu_ext.ext_.apply_rotary( - q_ro, - q_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - False - ) - k_ro = qkv[:, :, 1, :, :rotary_dim] - dipu_ext.ext_.apply_rotary( - k_ro, - k_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - False - ) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - interleaved = ctx.interleaved - _, seqlen, _, _, headdim = dqkv.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, :, 0, :, :rotary_dim] - dipu_ext.ext_.apply_rotary( - dq_ro, - dq_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - False - ) - dk_ro = dqkv[:, :, 1, :, :rotary_dim] - dipu_ext.ext_.apply_rotary( - dk_ro, - dk_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - False - ) - return dqkv, None, None, None, None, None - - class 
DeepLinkApplyRotaryEmb(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, cos, sin, interleaved=False, inplace=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - batch, seqlen, nheads, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) - out = torch.empty_like(x) if not inplace else x - out_ro = out[..., :rotary_dim] - - if inplace: - dipu_ext.ext_.apply_rotary( - out_ro, - x_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - False - ) - else: - dipu_ext.ext_.apply_rotary( - out_ro, - x_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - False - ) - - if not inplace and rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.inplace = inplace - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - inplace = ctx.inplace - do_ro = do[..., :rotary_dim] - do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved - else (do_ro[..., ::2], do_ro[..., 1::2])) - dx = torch.empty_like(do) if not inplace else do - if inplace: - dx1, dx2 = do1, do2 - else: - dx_ro = dx[..., :rotary_dim] - dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved - else (dx_ro[..., ::2], dx_ro[..., 1::2])) - if inplace: - dipu_ext.ext_.apply_rotary( - do_ro, - do_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - False - ) - else: - dipu_ext.ext_.apply_rotary( - dx_ro, - do_ro, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - False - ) - - if not inplace and rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None - -except: - print("NOT using ext apply_rotary") - pass diff --git a/ext_apply/internlm/ext_mha/__init__.py b/ext_apply/internlm/ext_mha/__init__.py deleted file mode 100644 index 783398d8..00000000 --- a/ext_apply/internlm/ext_mha/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2023, DeepLink. - -from .mha import DeepLinkSelfAttention, DeepLinkCrossAttention -from . 
import mha_fallback as fallback diff --git a/ext_apply/lightllm/ext_apply_rotary.py b/ext_apply/lightllm/ext_apply_rotary.py deleted file mode 100644 index 6369cf0a..00000000 --- a/ext_apply/lightllm/ext_apply_rotary.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -import dipu_ext.ext_ -from einops import rearrange - - -# Rotary_emb -# 本身就是基于pytorch的实现,所以不需要pytorch绕过代码 - -try: - import dipu_ext.ext_ - print("using ext apply_rotary") - def deeplink_rotary_emb(x, cos, sin): - seq_len, h, dim = x.shape - cos = cos.view((seq_len, 1, dim // 2)) - sin = sin.view((seq_len, 1, dim // 2)) - output = torch.empty_like(x) - dipu_ext.ext_.apply_rotary(output, x, cos, sin, False, False) - return output -except: - print("NOT using ext apply_rotary") - pass diff --git a/ext_apply/lightllm/mock_op.py b/ext_apply/lightllm/mock_op.py deleted file mode 100644 index d6a0e3e3..00000000 --- a/ext_apply/lightllm/mock_op.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -import os -try: - import dipu_ext.ext_ as ext - default_diopi_mock_op_list = ['dest_index_copy_kv', 'token_attention_inference', - 'token_softmax_reducev_inference', 'context_attention_inference', - 'apply_penalty', 'rms_norm_lightllm', 'rotary_emb'] - diopi_mock_op_list = os.environ.get('diopi_mock_op_list').split( - ',') if 'diopi_mock_op_list' in os.environ else default_diopi_mock_op_list - print(f"diopi_mock_op_list:{diopi_mock_op_list}") - if hasattr(ext, 'dest_index_copy_kv') and 'dest_index_copy_kv' in diopi_mock_op_list: - import lightllm.common.basemodel.triton_kernel.destindex_copy_kv as destindex_copy_kv_pack - destindex_copy_kv_pack.destindex_copy_kv = ext.dest_index_copy_kv - print("use diopi_dest_index_copy_kv as destindex_copy_kv") - - if hasattr(ext, 'apply_penalty') and 'apply_penalty' in diopi_mock_op_list: - import lightllm.common.basemodel.triton_kernel.apply_penalty as apply_penalty_pack - apply_penalty_pack.apply_penalty = ext.apply_penalty - print("use diopi_apply_penalty as apply_penalty") - - if hasattr(ext, 'context_attention_inference') and 'context_attention_inference' in diopi_mock_op_list: - import lightllm.models.llama.triton_kernel.context_flashattention_nopad as context_attention_pack - context_attention_pack.context_attention_fwd = ext.context_attention_inference - print("use diopi_context_attention_inference as context_attention_fwd") - - if hasattr(ext, 'token_attention_inference') and 'token_attention_inference' in diopi_mock_op_list: - import lightllm.models.llama.triton_kernel.token_attention_nopad_att1 as token_attention_pack - token_attention_pack.token_att_fwd = ext.token_attention_inference - print("use diopi_token_attention_inference as token_att_fwd") - - if hasattr(ext, 'token_softmax_reducev_inference') and 'token_softmax_reducev_inference' in diopi_mock_op_list: - import lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev as token_attention_softmax_reducev_pack - token_attention_softmax_reducev_pack.token_softmax_reducev_fwd = ext.token_softmax_reducev_inference - print("use diopi_token_softmax_reducev_inference as token_softmax_reducev_fwd") - - if hasattr(ext, 'rms_norm_lightllm') and 'rms_norm_lightllm' in diopi_mock_op_list: - import lightllm.models.llama.triton_kernel.rmsnorm as rms_norm_pack - rms_norm_pack.rmsnorm_forward = ext.rms_norm_lightllm - print("use diopi_rms_norm as rmsnorm_forward") - - if hasattr(ext, 'rotary_emb') and 'rotary_emb' in diopi_mock_op_list: - import lightllm.models.llama.triton_kernel.rotary_emb as rotary_emb_pack - rotary_emb_pack.rotary_emb_fwd = 
ext.rotary_emb - print("use diopi_rotary_embedding as rotary_emb_fwd") -except ImportError: - pass diff --git a/setup.py b/setup.py index d93b78d8..0f7fb640 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,10 @@ -from setuptools import setup, Extension +# Copyright (c) 2024, DeepLink. + +from setuptools import find_packages, setup, Extension from torch.utils.cpp_extension import BuildExtension, include_paths, library_paths import glob import os +import subprocess def _getenv_or_die(env_name: str): @@ -11,11 +14,31 @@ def _getenv_or_die(env_name: str): return env +class BuildExtensionWithCompdb(BuildExtension): + def build_extensions(self): + super().build_extensions() + try: + self._gen_compdb() + except Exception as e: + print(f"Failed to generate compile_commands.json: {e}") + + def _gen_compdb(self): + assert self.use_ninja + build_ninja_file = glob.glob("./build/**/build.ninja", recursive=True) + assert len(build_ninja_file) == 1 + with open("build/compile_commands.json", "w") as f: + subprocess.run( + ["ninja", "-f", build_ninja_file[0], "-t", "compdb"], + stdout=f, + check=True, + ) + print("Generated build/compile_commands.json") + + def get_ext(): - ext_name = "dipu_ext.ext_" - os.makedirs('dipu_ext', exist_ok=True) + ext_name = "deeplink_ext.cpp_extensions" # 包含所有算子文件 - op_files = glob.glob("./ext_op/*.cpp") + op_files = glob.glob("./csrc/*.cpp") include_dirs = [] system_include_dirs = include_paths() define_macros = [] @@ -40,9 +63,8 @@ def get_ext(): library_dirs += [dipu_root] libraries += ["torch_dipu"] - extra_compile_args = {"cxx": []} - extra_compile_args["cxx"] = ["-std=c++14"] - extra_compile_args["cxx"] += ["-isystem" + path for path in system_include_dirs] + extra_compile_args = ["-std=c++17"] + extra_compile_args += ["-isystem" + path for path in system_include_dirs] ext_ops = Extension( name=ext_name, # 拓展模块名字 sources=op_files, @@ -57,4 +79,10 @@ def get_ext(): return [ext_ops] -setup(name="dipu_ext", ext_modules=get_ext(), cmdclass={"build_ext": BuildExtension}) +setup( + name="deeplink_ext", + packages=find_packages(exclude=["build", "csrc", "tests"]), + ext_modules=get_ext(), + cmdclass={"build_ext": BuildExtensionWithCompdb}, + install_requires=["einops"], +) diff --git a/test/test_mha_internlm.py b/tests/test_mha_internlm.py similarity index 77% rename from test/test_mha_internlm.py rename to tests/test_mha_internlm.py index e8c67a05..b74ecc47 100644 --- a/test/test_mha_internlm.py +++ b/tests/test_mha_internlm.py @@ -1,8 +1,7 @@ # Copyright (c) 2023, DeepLink. 
import torch -import torch_dipu -import DeepLinkExt.ext_apply.internlm.ext_mha as ext_mha +import deeplink_ext.internlm_ops.mha as ext def _run_self_attention(self_attn_module: type, qkv_data: torch.Tensor): @@ -29,8 +28,8 @@ def _run_cross_attention( H = 2 D = 8 qkv = torch.randn(B, S, 3, H, D, dtype=torch.float16).cuda() -output_gold, grad_gold = _run_self_attention(ext_mha.fallback.SelfAttention, qkv) -output_ext, grad_ext = _run_self_attention(ext_mha.DeepLinkSelfAttention, qkv) +output_gold, grad_gold = _run_self_attention(ext.fallback.SelfAttention, qkv) +output_ext, grad_ext = _run_self_attention(ext.DeepLinkSelfAttention, qkv) assert torch.allclose(output_gold, output_ext, atol=1e-3) print("SelfAttention forward test pass") assert torch.allclose(grad_gold, grad_ext, atol=2e-3) @@ -39,11 +38,9 @@ def _run_cross_attention( q = qkv[:, :, 0] kv = qkv[:, :, 1:] output_gold, dq_gold, dkv_gold = _run_cross_attention( - ext_mha.fallback.CrossAttention, q, kv -) -output_ext, dq_ext, dkv_ext = _run_cross_attention( - ext_mha.DeepLinkCrossAttention, q, kv + ext.fallback.CrossAttention, q, kv ) +output_ext, dq_ext, dkv_ext = _run_cross_attention(ext.DeepLinkCrossAttention, q, kv) assert torch.allclose(output_gold, output_ext, atol=1e-3) print("CrossAttention forward test pass") assert torch.allclose(dq_gold, dq_ext, atol=2e-3) diff --git a/test/test_rms_internlm.py b/tests/test_rms_internlm.py similarity index 63% rename from test/test_rms_internlm.py rename to tests/test_rms_internlm.py index d476babb..8513c74d 100644 --- a/test/test_rms_internlm.py +++ b/tests/test_rms_internlm.py @@ -1,12 +1,8 @@ +# Copyright (c) 2023, DeepLink. + import torch -import torch_dipu import numpy as np - -from ext_apply.internlm.RMSNorm import ( - InternLMRMSNorm, - DeeplinkRMSNorm, - DeeplinkRMSNorm_WithNormalizedShape, -) +import deeplink_ext.internlm_ops.rms_norm as ext def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): @@ -30,5 +26,12 @@ def test_rms_norm(BaseRmsNorm, DeeplinkRmsNorm, rtol=1e-4, atol=1e-3): return np.allclose(grad_x_base, grad_x_intern, rtol, atol, True) -print("Test case: normalized_shape == None: grad_inputs closed ? ", test_rms_norm(InternLMRMSNorm, DeeplinkRMSNorm)) -print("Test case: normalized_shape == weight.size(): grad_inputs closed ? ", test_rms_norm(InternLMRMSNorm, DeeplinkRMSNorm_WithNormalizedShape)) + +print( + "Test case: normalized_shape == None: grad_inputs closed ? ", + test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNorm), +) +print( + "Test case: normalized_shape == weight.size(): grad_inputs closed ? ", + test_rms_norm(ext.fallback.RMSNorm, ext.DeepLinkRMSNormWithNormalizedShape), +) diff --git a/test/test_rms_lightlm.py b/tests/test_rms_lightlm.py similarity index 59% rename from test/test_rms_lightlm.py rename to tests/test_rms_lightlm.py index 2f880c4d..ea9f66b3 100644 --- a/test/test_rms_lightlm.py +++ b/tests/test_rms_lightlm.py @@ -1,11 +1,7 @@ -import torch - -import dipu_ext.ext_ as deeplink_ext -import torch_dipu -import pdb -# import debugat +# Copyright (c) 2023, DeepLink. 
-# 假设 deeplink_ext 是一个包含上述 RMS normalization 函数的模块 +import torch +import deeplink_ext.cpp_extensions as ext # 定义输入张量 input = torch.randn(5, 5, requires_grad=True).cuda() @@ -20,29 +16,12 @@ # 归一化的形状通常是输入张量的形状 normalized_shape = torch.tensor([5, 5], dtype=torch.long).cuda() -# pdb.set_trace() - -# 使用 RMS normalization 前向传播 -# while True: - print(input.is_dipu) -output, inv_rms = deeplink_ext.rms_norm( - input, - None, - weight, - bias, - 1e-6 -) +output, inv_rms = ext.rms_norm(input, None, weight, bias, 1e-6) # 使用 RMS normalization 反向传播 -grad_input, grad_weight, grad_bias = deeplink_ext.rms_norm_backward( - input, - grad_output, - inv_rms, - None, - weight, - bias, - 1e-6 +grad_input, grad_weight, grad_bias = ext.rms_norm_backward( + input, grad_output, inv_rms, None, weight, bias, 1e-6 ) print("Output:", output) diff --git a/test/test_rotary_emb_internlm.py b/tests/test_rotary_emb_internlm.py similarity index 70% rename from test/test_rotary_emb_internlm.py rename to tests/test_rotary_emb_internlm.py index 390bc430..7cc4b5e2 100644 --- a/test/test_rotary_emb_internlm.py +++ b/tests/test_rotary_emb_internlm.py @@ -1,27 +1,25 @@ +# Copyright (c) 2023, DeepLink. + import torch -import torch_dipu -from einops import rearrange -import dipu_ext.ext_ -from DeepLinkExt.ext_apply.internlm.ext_apply_rotary import ( - TorchApplyRotaryEmbQKV_, - DeepLinkApplyRotaryEmbQKV_, - TorchApplyRotaryEmb, - DeepLinkApplyRotaryEmb, -) +import deeplink_ext.internlm_ops.rotary as ext -def RotaryEmbTest(func_name): +def RotaryEmbTest(func_name): if func_name == "RotaryEmbQKV": - torch_apply = TorchApplyRotaryEmbQKV_.apply - dipu_apply = DeepLinkApplyRotaryEmbQKV_.apply - input = torch.randn(1, 125, 3, 16, 32, dtype=torch.float16, requires_grad=True).cuda() + torch_apply = ext.fallback.ApplyRotaryEmbQKV_.apply + dipu_apply = ext.DeepLinkApplyRotaryEmbQKV_.apply + input = torch.randn( + 1, 125, 3, 16, 32, dtype=torch.float16, requires_grad=True + ).cuda() elif func_name == "RotaryEmb": - torch_apply = TorchApplyRotaryEmb.apply - dipu_apply = DeepLinkApplyRotaryEmb.apply - input = torch.randn(1, 125, 16, 32, dtype=torch.float16, requires_grad=True).cuda() + torch_apply = ext.fallback.ApplyRotaryEmb.apply + dipu_apply = ext.DeepLinkApplyRotaryEmb.apply + input = torch.randn( + 1, 125, 16, 32, dtype=torch.float16, requires_grad=True + ).cuda() else: - print(f"{func_name} is not supported.") - return False + print(f"{func_name} is not supported.") + return False loss_fn = torch.nn.MSELoss() cos = torch.randn(257, 16, dtype=torch.float16).cuda() @@ -42,9 +40,8 @@ def RotaryEmbTest(func_name): res1 = torch_apply(input, cos, sin, interleaved) res2 = dipu_apply(input1, cos1, sin1, interleaved) else: - print(f"{func_name} is not supported.") - return False - + print(f"{func_name} is not supported.") + return False # 验证前向传播结果 forward_correct = torch.allclose(res1, res2) @@ -78,5 +75,6 @@ def RotaryEmbTest(func_name): ) return False + assert RotaryEmbTest("RotaryEmbQKV") assert RotaryEmbTest("RotaryEmb") diff --git a/test/test_rotary_emb_lightllm.py b/tests/test_rotary_emb_lightllm.py similarity index 71% rename from test/test_rotary_emb_lightllm.py rename to tests/test_rotary_emb_lightllm.py index 680b5203..0ea86628 100644 --- a/test/test_rotary_emb_lightllm.py +++ b/tests/test_rotary_emb_lightllm.py @@ -1,8 +1,7 @@ +# Copyright (c) 2023, DeepLink. 
+ import torch -import torch_dipu -from einops import rearrange -import dipu_ext.ext_ -from DeepLinkExt.ext_apply.lightllm.ext_apply_rotary import deeplink_rotary_emb +import deeplink_ext.cpp_extensions as ext # lightllm的实现 @@ -17,6 +16,15 @@ def torch_rotary_emb(x, cos, sin): return torch.cat((o0, o1), dim=-1) +def deeplink_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + output = torch.empty_like(x) + ext.apply_rotary(output, x, cos, sin, False, False) + return output + + # 构造示例输入数据 seq_len = 4 h = 2