Commit c31b7e5

rebase fix

micmelesse committed Oct 14, 2024
1 parent 2632828 commit c31b7e5
Showing 2 changed files with 78 additions and 76 deletions.
1 change: 1 addition & 0 deletions scripts
Submodule scripts added at 0dd7fe
153 changes: 77 additions & 76 deletions setup.py
@@ -305,38 +305,38 @@ def validate_and_update_archs(archs):
)
)
elif not SKIP_CUDA_BUILD and IS_ROCM:
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])


if USE_TRITON_ROCM:
# Skip C++ extension compilation if using Triton Backend
pass
else:
ck_dir = "csrc/composable_kernel"

# use codegen to generate the dispatch code
if not os.path.exists("./build"):
os.makedirs("build")

os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2")
os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2")
os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")


# Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]

check_if_rocm_home_none("flash_attn")
archs = os.getenv("GPU_ARCHS", "native").split(";")
validate_and_update_archs(archs)

cc_flag = [f"--offload-arch={arch}" for arch in archs]

@@ -346,72 +346,73 @@ def validate_and_update_archs(archs):
if FORCE_CXX11_ABI:
torch._C._GLIBCXX_USE_CXX11_ABI = True

sources = ["csrc/flash_attn_ck/flash_api.cpp",
"csrc/flash_attn_ck/flash_common.cpp",
"csrc/flash_attn_ck/mha_bwd.cpp",
"csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
"csrc/flash_attn_ck/mha_fwd.cpp",
"csrc/flash_attn_ck/mha_varlen_bwd.cpp",
"csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
f"build/fmha_*wd*.cpp"
)

rename_cpp_to_cu(sources)

renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
"csrc/flash_attn_ck/flash_common.cu",
"csrc/flash_attn_ck/mha_bwd.cu",
"csrc/flash_attn_ck/mha_fwd_kvcache.cu",
"csrc/flash_attn_ck/mha_fwd.cu",
"csrc/flash_attn_ck/mha_varlen_bwd.cu",
"csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")

cc_flag += ["-O3","-std=c++17",
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-DCK_ENABLE_BF16",
"-DCK_ENABLE_BF8",
"-DCK_ENABLE_FP16",
"-DCK_ENABLE_FP32",
"-DCK_ENABLE_FP64",
"-DCK_ENABLE_FP8",
"-DCK_ENABLE_INT8",
"-DCK_USE_XDL",
"-DUSE_PROF_API=1",
# "-DFLASHATTENTION_DISABLE_BACKWARD",
"-D__HIP_PLATFORM_HCC__=1"]

cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]

# Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
hip_version = get_hip_version()
if hip_version > Version('5.7.23302'):
cc_flag += ["-fno-offload-uniform-block"]
if hip_version > Version('6.1.40090'):
cc_flag += ["-mllvm", "-enable-post-misched=0"]
if hip_version > Version('6.2.41132'):
cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
"-mllvm", "-amdgpu-function-calls=false"]
if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]

extra_compile_args = {
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"nvcc": cc_flag + generator_flag,
}

include_dirs = [
Path(this_dir) / "csrc" / "composable_kernel" / "include",
Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
]
sources = ["csrc/flash_attn_ck/flash_api.cpp",
"csrc/flash_attn_ck/flash_common.cpp",
"csrc/flash_attn_ck/mha_bwd.cpp",
"csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
"csrc/flash_attn_ck/mha_fwd.cpp",
"csrc/flash_attn_ck/mha_varlen_bwd.cpp",
"csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
f"build/fmha_*wd*.cpp"
)

rename_cpp_to_cu(sources)
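rename_cpp_to_cu is defined earlier in setup.py and is not shown in this diff; purely as a hedged sketch of what such a helper typically does (an assumption, not the project's code), it renames each .cpp source to .cu in place so hipcc compiles it as device code, which is why renamed_sources below lists .cu files.

import os

# Hypothetical sketch of a cpp -> cu renamer (assumed behaviour, not the real helper).
def rename_cpp_to_cu_sketch(cpp_files):
    for path in cpp_files:
        if path.endswith(".cpp"):
            os.rename(path, path[:-len(".cpp")] + ".cu")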

renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
"csrc/flash_attn_ck/flash_common.cu",
"csrc/flash_attn_ck/mha_bwd.cu",
"csrc/flash_attn_ck/mha_fwd_kvcache.cu",
"csrc/flash_attn_ck/mha_fwd.cu",
"csrc/flash_attn_ck/mha_varlen_bwd.cu",
"csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")

cc_flag += ["-O3","-std=c++17",
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-DCK_ENABLE_BF16",
"-DCK_ENABLE_BF8",
"-DCK_ENABLE_FP16",
"-DCK_ENABLE_FP32",
"-DCK_ENABLE_FP64",
"-DCK_ENABLE_FP8",
"-DCK_ENABLE_INT8",
"-DCK_USE_XDL",
"-DUSE_PROF_API=1",
# "-DFLASHATTENTION_DISABLE_BACKWARD",
"-D__HIP_PLATFORM_HCC__=1"]

cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]

# Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
hip_version = get_hip_version()
if hip_version > Version('5.7.23302'):
cc_flag += ["-fno-offload-uniform-block"]
if hip_version > Version('6.1.40090'):
cc_flag += ["-mllvm", "-enable-post-misched=0"]
if hip_version > Version('6.2.41132'):
cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
"-mllvm", "-amdgpu-function-calls=false"]
if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]

extra_compile_args = {
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"nvcc": cc_flag + generator_flag,
}

include_dirs = [
Path(this_dir) / "csrc" / "composable_kernel" / "include",
Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
]

ext_modules.append(
CUDAExtension(
name="flash_attn_2_cuda",
sources=renamed_sources,
extra_compile_args=extra_compile_args,
include_dirs=include_dirs,
)
)
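Further down (outside this hunk), setup.py hands ext_modules to setuptools; a minimal sketch of that wiring, assuming the standard torch.utils.cpp_extension builder rather than whatever custom cmdclass the project actually uses:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

# Sketch only: compile the CUDAExtension registered above (hipified for ROCm by torch's builder).
setup(
    name="flash_attn",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)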



0 comments on commit c31b7e5
