Skip to content

Commit

Permalink
Use CUDA runtime API to retrieve function pointer to driver API (#1700)
Browse files Browse the repository at this point in the history
* Query pfn to driver api

* use default for older toolkits

---------

Co-authored-by: shunfans <[email protected]>
  • Loading branch information
shunfan-shao and shunfans authored Aug 19, 2024
1 parent f93a691 commit 4dbf5db
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 2 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
# Opt-in switch for calling CUDA driver API functions directly instead of
# resolving them through the CUDA runtime entry-point lookup.
set(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL OFF CACHE BOOL "Enable CUTLASS to directly call driver API")
################################################################################
#
Expand Down
4 changes: 3 additions & 1 deletion include/cute/atom/copy_traits_sm90_im2col.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@

#include "cute/algorithm/prefetch.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/cuda_host_adapter.hpp"

namespace cute
{

Expand Down Expand Up @@ -450,7 +452,7 @@ make_im2col_tma_copy_desc(
CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_);
CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle));

CUresult encode_result = cuTensorMapEncodeIm2col(
CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
&tma_desc,
tma_format,
num_total_modes,
Expand Down
3 changes: 2 additions & 1 deletion include/cute/atom/copy_traits_sm90_tma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <cute/algorithm/prefetch.hpp>

#include <cute/numeric/integral_ratio.hpp>
#include <cutlass/cuda_host_adapter.hpp>

namespace cute
{
Expand Down Expand Up @@ -983,7 +984,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor, // The origin

// TMA smem swizzle type
CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
CUresult result = cuTensorMapEncodeTiled(
CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tma_desc,
tma_format,
tma_dim,
Expand Down
72 changes: 72 additions & 0 deletions include/cutlass/cuda_host_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,78 @@ namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

#if !defined(__CUDACC_RTC__)

// Host-side wrappers around CUDA driver API entry points. This whole region is
// skipped when compiling under NVRTC (__CUDACC_RTC__).
//
// NOTE(review): judging by the CUTLASS_CUDA_DRIVER_WRAPPER_CALL expansion below,
// this region sits inside `namespace cutlass`; the #includes here presumably rely
// on the headers having already been pulled in (and include-guarded) at global
// scope earlier in the file -- confirm.

#include <cudaTypedefs.h>
#include <driver_types.h>

// Turns a token into a string literal; used to hand the driver symbol name to
// the CUDA runtime entry-point lookup.
#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok

// CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) defines a function template
//
//   template <typename... Args> CUresult call_<func>(Args... args);
//
// that forwards `args` to the CUDA driver API function `func`.
//   func - driver API function name (e.g. cuTensorMapEncodeTiled)
//   ver  - CUDA version of the entry point to request (e.g. 12000); only used
//          by the cudaGetDriverEntryPointByVersion variant.
// Returns the CUresult of the driver call, or CUDA_ERROR_UNKNOWN when the entry
// point could not be resolved through the runtime.
//
// NOTE(review): the direct-call variant is selected with `defined(...)`, so a
// definition with value 0 would still enable it -- confirm how the build system
// passes CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL.
#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)

// Direct mode: reference the driver API symbol itself.
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \
  template <typename... Args>                       \
  CUresult call_##func(Args... args) {              \
    return func(args...);                           \
  }

#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)

// Toolkit 12.5 or newer. The major and minor versions must be compared as a
// pair: a plain `major >= 12 && minor >= 5` test wrongly rejects 13.0 - 13.4.
#if ((__CUDACC_VER_MAJOR__ >= 13) || \
     ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5)))

// Indirect mode, versioned lookup: ask the runtime for the `ver` flavour of
// `func` and call it through the matching PFN_<func>_v<ver> typedef.
// The runtime error code is tested first so that the query status is only read
// after the lookup call itself succeeded; a null pointer is never invoked.
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver)             \
  template <typename... Args>                                   \
  CUresult call_##func(Args... args) {                          \
    cudaDriverEntryPointQueryResult cuda_status;                \
    void* pfn = nullptr;                                        \
    cudaError_t cuda_err = cudaGetDriverEntryPointByVersion(    \
        CUTLASS_CUDA_DRIVER_STRINGIFY(func),                    \
        &pfn, ver,                                              \
        cudaEnableDefault,                                      \
        &cuda_status);                                          \
    if (cuda_err != cudaSuccess ||                              \
        cuda_status != cudaDriverEntryPointSuccess ||           \
        pfn == nullptr) {                                       \
      return CUDA_ERROR_UNKNOWN;                                \
    }                                                           \
    return reinterpret_cast<PFN_##func##_v##ver>(pfn)(args...); \
  }

#else

// Indirect mode, legacy lookup for older toolkits: `ver` is ignored and the
// default entry point is requested, called through the PFN_<func> typedef.
// Same failure handling as the versioned variant.
#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver)             \
  template <typename... Args>                                   \
  CUresult call_##func(Args... args) {                          \
    cudaDriverEntryPointQueryResult cuda_status;                \
    void* pfn = nullptr;                                        \
    cudaError_t cuda_err = cudaGetDriverEntryPoint(             \
        CUTLASS_CUDA_DRIVER_STRINGIFY(func),                    \
        &pfn,                                                   \
        cudaEnableDefault,                                      \
        &cuda_status);                                          \
    if (cuda_err != cudaSuccess ||                              \
        cuda_status != cudaDriverEntryPointSuccess ||           \
        pfn == nullptr) {                                       \
      return CUDA_ERROR_UNKNOWN;                                \
    }                                                           \
    return reinterpret_cast<PFN_##func>(pfn)(args...);          \
  }

#endif // CUDA toolkit >= 12.5

#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL)

// The wrappers are only declared for CUDA 12 and newer toolkits.
// No trailing ';' after the invocations: the macro already expands to a complete
// function definition, and an extra ';' would be an empty declaration.
#if (__CUDACC_VER_MAJOR__ >= 12)
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000)
CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000)
#endif

#undef CUTLASS_CUDA_DRIVER_STRINGIFY

// Names the wrapper generated for `func`; use as
//   CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(args...)
#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func

#endif // !defined(__CUDACC_RTC__)

/////////////////////////////////////////////////////////////////////////////////////////////////

/// This class manages the runtime CUlaunchAttribute values that can be supplied to CudaHostAdapter.
/// CudaHostLaunchAttributes is an empty struct on earlier CUDA Toolkits (CTK) that predate
/// CUlaunchAttribute.
Expand Down

0 comments on commit 4dbf5db

Please sign in to comment.