From c31b7e596536b33b98f554d3285046b872ab5a71 Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Mon, 14 Oct 2024 11:24:20 -0500
Subject: [PATCH] rebase fix

---
 scripts  |   1 +
 setup.py | 153 ++++++++++++++++++++++++++++---------------------------
 2 files changed, 78 insertions(+), 76 deletions(-)
 create mode 160000 scripts

diff --git a/scripts b/scripts
new file mode 160000
index 000000000..0dd7fe36a
--- /dev/null
+++ b/scripts
@@ -0,0 +1 @@
+Subproject commit 0dd7fe36a10eddf9892d7bf2dada21513f858a4a
diff --git a/setup.py b/setup.py
index 0f7e65ab7..be32ad809 100644
--- a/setup.py
+++ b/setup.py
@@ -305,28 +305,28 @@ def validate_and_update_archs(archs):
         )
     )
 elif not SKIP_CUDA_BUILD and IS_ROCM:
-    #use codegen get code dispatch
-    if not os.path.exists("./build"):
-        os.makedirs("build")
-
-    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
-    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2")
-    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2")
-    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
     print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
     TORCH_MAJOR = int(torch.__version__.split(".")[0])
     TORCH_MINOR = int(torch.__version__.split(".")[1])
+
     if USE_TRITON_ROCM:
         # Skip C++ extension compilation if using Triton Backend
         pass
     else:
         ck_dir = "csrc/composable_kernel"
+        #use codegen get code dispatch
+        if not os.path.exists("./build"):
+            os.makedirs("build")
+        os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
+        os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2")
+        os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2")
         os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
+
         # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
         # See https://github.com/pytorch/pytorch/pull/70650
         generator_flag = []
@@ -334,9 +334,9 @@ def validate_and_update_archs(archs):
         if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
             generator_flag = ["-DOLD_GENERATOR_PATH"]

-    check_if_rocm_home_none("flash_attn")
-    archs = os.getenv("GPU_ARCHS", "native").split(";")
-    validate_and_update_archs(archs)
+        check_if_rocm_home_none("flash_attn")
+        archs = os.getenv("GPU_ARCHS", "native").split(";")
+        validate_and_update_archs(archs)

         cc_flag = [f"--offload-arch={arch}" for arch in archs]
@@ -346,72 +346,73 @@ def validate_and_update_archs(archs):
         if FORCE_CXX11_ABI:
             torch._C._GLIBCXX_USE_CXX11_ABI = True
-    sources = ["csrc/flash_attn_ck/flash_api.cpp",
-               "csrc/flash_attn_ck/flash_common.cpp",
-               "csrc/flash_attn_ck/mha_bwd.cpp",
-               "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
-               "csrc/flash_attn_ck/mha_fwd.cpp",
-               "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
-               "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
-        f"build/fmha_*wd*.cpp"
-    )
-
-    rename_cpp_to_cu(sources)
-
-    renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
-                       "csrc/flash_attn_ck/flash_common.cu",
-                       "csrc/flash_attn_ck/mha_bwd.cu",
-                       "csrc/flash_attn_ck/mha_fwd_kvcache.cu",
-                       "csrc/flash_attn_ck/mha_fwd.cu",
-                       "csrc/flash_attn_ck/mha_varlen_bwd.cu",
-                       "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
-
-    cc_flag += ["-O3","-std=c++17",
-                "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
-                "-fgpu-flush-denormals-to-zero",
-                "-DCK_ENABLE_BF16",
-                "-DCK_ENABLE_BF8",
-                "-DCK_ENABLE_FP16",
-                "-DCK_ENABLE_FP32",
-                "-DCK_ENABLE_FP64",
-                "-DCK_ENABLE_FP8",
-                "-DCK_ENABLE_INT8",
-                "-DCK_USE_XDL",
-                "-DUSE_PROF_API=1",
-                # "-DFLASHATTENTION_DISABLE_BACKWARD",
-                "-D__HIP_PLATFORM_HCC__=1"]
-
-    cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]
-
-    # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
-    hip_version = get_hip_version()
-    if hip_version > Version('5.7.23302'):
-        cc_flag += ["-fno-offload-uniform-block"]
-    if hip_version > Version('6.1.40090'):
-        cc_flag += ["-mllvm", "-enable-post-misched=0"]
-    if hip_version > Version('6.2.41132'):
-        cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
-                    "-mllvm", "-amdgpu-function-calls=false"]
-    if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
-        cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]
-
-    extra_compile_args = {
-        "cxx": ["-O3", "-std=c++17"] + generator_flag,
-        "nvcc": cc_flag + generator_flag,
-    }
-
-    include_dirs = [
-        Path(this_dir) / "csrc" / "composable_kernel" / "include",
-        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
-        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
-    ]
-
-    ext_modules.append(
-        CUDAExtension(
-            name="flash_attn_2_cuda",
-            sources=renamed_sources,
-            extra_compile_args=extra_compile_args,
-            include_dirs=include_dirs,
+        sources = ["csrc/flash_attn_ck/flash_api.cpp",
+                   "csrc/flash_attn_ck/flash_common.cpp",
+                   "csrc/flash_attn_ck/mha_bwd.cpp",
+                   "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
+                   "csrc/flash_attn_ck/mha_fwd.cpp",
+                   "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
+                   "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
+            f"build/fmha_*wd*.cpp"
+        )
+
+        rename_cpp_to_cu(sources)
+
+        renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
+                           "csrc/flash_attn_ck/flash_common.cu",
+                           "csrc/flash_attn_ck/mha_bwd.cu",
+                           "csrc/flash_attn_ck/mha_fwd_kvcache.cu",
+                           "csrc/flash_attn_ck/mha_fwd.cu",
+                           "csrc/flash_attn_ck/mha_varlen_bwd.cu",
+                           "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
+
+        cc_flag += ["-O3","-std=c++17",
+                    "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
+                    "-fgpu-flush-denormals-to-zero",
+                    "-DCK_ENABLE_BF16",
+                    "-DCK_ENABLE_BF8",
+                    "-DCK_ENABLE_FP16",
+                    "-DCK_ENABLE_FP32",
+                    "-DCK_ENABLE_FP64",
+                    "-DCK_ENABLE_FP8",
+                    "-DCK_ENABLE_INT8",
+                    "-DCK_USE_XDL",
+                    "-DUSE_PROF_API=1",
+                    # "-DFLASHATTENTION_DISABLE_BACKWARD",
+                    "-D__HIP_PLATFORM_HCC__=1"]
+
+        cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]
+
+        # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
+        hip_version = get_hip_version()
+        if hip_version > Version('5.7.23302'):
+            cc_flag += ["-fno-offload-uniform-block"]
+        if hip_version > Version('6.1.40090'):
+            cc_flag += ["-mllvm", "-enable-post-misched=0"]
+        if hip_version > Version('6.2.41132'):
+            cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
+                        "-mllvm", "-amdgpu-function-calls=false"]
+        if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
+            cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]
+
+        extra_compile_args = {
+            "cxx": ["-O3", "-std=c++17"] + generator_flag,
+            "nvcc": cc_flag + generator_flag,
+        }
+
+        include_dirs = [
+            Path(this_dir) / "csrc" / "composable_kernel" / "include",
+            Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
+            Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
+        ]
+
+        ext_modules.append(
+            CUDAExtension(
+                name="flash_attn_2_cuda",
+                sources=renamed_sources,
+                extra_compile_args=extra_compile_args,
+                include_dirs=include_dirs,
+            )
         )
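
Note on the codegen hunk: the main effect of the patch is that the CK tile FMHA codegen now runs only when the C++ extension is actually built, i.e. inside the else branch of the USE_TRITON_ROCM check instead of unconditionally at the top of the ROCm branch, which avoids invoking the composable_kernel generators on installs that only use the Triton backend. Below is a minimal standalone sketch of that gating; it is not part of the patch, the environment variable used to derive USE_TRITON_ROCM is an assumption (setup.py defines it elsewhere), and subprocess.run replaces the patch's os.system purely for illustration.

# Illustrative sketch, not part of the patch. The env var name for the Triton
# backend switch is an assumption; setup.py defines USE_TRITON_ROCM elsewhere.
import os
import subprocess
import sys

ck_dir = "csrc/composable_kernel"
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"  # assumed flag name

def generate_ck_fmha_sources(output_dir="build"):
    # Run the CK tile FMHA codegen once per kernel kind the extension needs.
    os.makedirs(output_dir, exist_ok=True)
    for direction in ("fwd", "fwd_appendkv", "fwd_splitkv", "bwd"):
        subprocess.run(
            [sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py",
             "-d", direction, "--output_dir", output_dir, "--receipt", "2"],
            check=True,
        )

if USE_TRITON_ROCM:
    # Triton backend: no C++ extension is compiled, so no generated sources are needed.
    pass
else:
    generate_ck_fmha_sources()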
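
Note on the HIP-version-gated flags: the re-indented block compares get_hip_version() against packaging Version thresholds and appends extra -mllvm options, mirroring the composable_kernel CMake logic linked in the comment. The sketch below restates that mapping as a standalone function; get_hip_version() here is a stand-in that parses `hipcc --version` output and may differ from the helper actually defined in setup.py.

# Illustrative sketch, not part of the patch. get_hip_version() below is a
# stand-in that parses `hipcc --version`; the helper in setup.py may differ.
import re
import subprocess
from packaging.version import Version

def get_hip_version():
    out = subprocess.run(["hipcc", "--version"], capture_output=True, text=True).stdout
    match = re.search(r"HIP version:\s*([0-9]+\.[0-9]+\.[0-9]+)", out)
    return Version(match.group(1)) if match else Version("0.0.0")

def hip_extra_cc_flags(hip_version):
    # Mirror the version thresholds from the hunk above (and the linked
    # composable_kernel CMakeLists): newer HIP toolchains get extra tuning flags.
    flags = []
    if hip_version > Version("5.7.23302"):
        flags += ["-fno-offload-uniform-block"]
    if hip_version > Version("6.1.40090"):
        flags += ["-mllvm", "-enable-post-misched=0"]
    if hip_version > Version("6.2.41132"):
        flags += ["-mllvm", "-amdgpu-early-inline-all=true",
                  "-mllvm", "-amdgpu-function-calls=false"]
    if Version("6.2.41133") < hip_version < Version("6.3.00000"):
        flags += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]
    return flags

# Example use: cc_flag += hip_extra_cc_flags(get_hip_version())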
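
Note on rename_cpp_to_cu(): the hunk only shows the call site and the fact that each *.cpp source reappears with a *.cu suffix in renamed_sources; the helper's body is outside this diff. The sketch below is an assumed, plausible equivalent (copying rather than moving, so the originals stay in the tree), not the actual implementation. The patch then lists the .cu names explicitly, together with the generated build/fmha_*wd*.cu files, and hands them to CUDAExtension, so no return value from the helper is relied on.

# Illustrative sketch, not part of the patch: an assumed equivalent of the
# rename_cpp_to_cu() helper called above, whose body is outside this diff.
import shutil
from pathlib import Path

def rename_cpp_to_cu(cpp_files):
    # Give each .cpp source a .cu twin so CUDAExtension routes it through the
    # device compiler (hipcc under ROCm); copying keeps the originals in place.
    for cpp in cpp_files:
        shutil.copy(cpp, Path(cpp).with_suffix(".cu"))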