diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index c0bbf2492b..25f9d0688d 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -9,7 +9,9 @@ #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" // @manual #include "fbgemm_gpu/ops_utils.h" // @manual #include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual +#ifdef USE_ROCM #include +#endif // clang-format off #include "fbgemm_gpu/cub_namespace_prefix.cuh" // @manual #include @@ -297,7 +299,7 @@ transpose_embedding_input( } { size_t temp_storage_bytes = 0; -#ifdef __HIP_PLATFORM_NVIDIA__ +#ifndef USE_ROCM AT_CUDA_CHECK( FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( nullptr,