From 52f7afeed35a9a3a0f8360196c889b96fe564ff6 Mon Sep 17 00:00:00 2001
From: Wassim KABALAN
Date: Fri, 5 Jul 2024 12:06:20 +0200
Subject: [PATCH 1/6] fix: Reduce halo; it should be half of the halo size

---
 jaxdecomp/_src/halo.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py
index e61fe8c..86ac046 100644
--- a/jaxdecomp/_src/halo.py
+++ b/jaxdecomp/_src/halo.py
@@ -220,7 +220,9 @@ def per_shard_impl(x: Array, halo_extents: Tuple[int, int, int],
   )

   if reduce_halo:
-    halo_x, halo_y, halo_z = halo_extents
+    # Padding is usually halo_size and the halo_exchange extents are halo_size // 2,
+    # so the reduction is done on half of the halo_size
+    halo_x, halo_y, halo_z = [extent * 2 for extent in halo_extents]

     # Apply corrections along x
     if halo_x > 0:
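The relationship this patch relies on, as a minimal sketch (the names and the
concrete halo_size value are illustrative, taken from the comment in the hunk
above, not from the jaxdecomp API): the array is padded by halo_size per side,
halo_exchange is handed extents of halo_size // 2, so the reduction step must
act on extent * 2 elements to span the full padded region.

    halo_size = 4
    # What per_shard_impl receives as halo_extents:
    halo_extents = (halo_size // 2,) * 3
    # The corrected reduction extents, as computed in the hunk above:
    halo_x, halo_y, halo_z = [extent * 2 for extent in halo_extents]
    assert (halo_x, halo_y, halo_z) == (halo_size,) * 3
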
From 0ae6f5674c5811a95279ec60b13b0405c31167e3 Mon Sep 17 00:00:00 2001
From: Wassim KABALAN
Date: Fri, 5 Jul 2024 12:41:27 +0200
Subject: [PATCH 2/6] Allocate using CUDA for halo_exchange by default; allow
 fallback to XLA

---
 include/halo.h |  8 +++++++-
 src/halo.cu    | 52 ++++++++++++++++++++++++++++++++++++++------------
 2 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/include/halo.h b/include/halo.h
index aa40710..558819a 100644
--- a/include/halo.h
+++ b/include/halo.h
@@ -36,19 +36,25 @@ template <typename real_t> class HaloExchange {
   friend class GridDescriptorManager;

 public:
-  HaloExchange() = default;
+  HaloExchange() : m_Tracer("JAXDECOMP") {}
   // Grid descriptors are handled by the GridDescriptorManager
+  // No memory should be cleaned up here; everything is handled by the GridDescriptorManager
   ~HaloExchange() = default;

   HRESULT get_halo_descriptor(cudecompHandle_t handle, size_t& work_size, haloDescriptor_t& halo_desc);
   HRESULT halo_exchange(cudecompHandle_t handle, haloDescriptor_t desc, cudaStream_t stream, void** buffers);

 private:
+  AsyncLogger m_Tracer;
+
   cudecompGridDesc_t m_GridConfig;
   cudecompGridDescConfig_t m_GridDescConfig;
   cudecompPencilInfo_t m_PencilInfo;
   int64_t m_WorkSize;
+  void* m_WorkSizeBuffer;
+
+  HRESULT cleanUp(cudecompHandle_t handle);
 };
 } // namespace jaxdecomp

diff --git a/src/halo.cu b/src/halo.cu
index 27093c5..42fdd7c 100644
--- a/src/halo.cu
+++ b/src/halo.cu
@@ -24,16 +24,12 @@ HRESULT HaloExchange<real_t>::get_halo_descriptor(cudecompHandle_t handle, size_
   CHECK_CUDECOMP_EXIT(
       cudecompGetPencilInfo(handle, m_GridConfig, &m_PencilInfo, halo_desc.axis, halo_desc.halo_extents.data()));

-  cudecompPencilInfo_t no_halo;
-
-  // Get pencil information for the specified axis
-  CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &no_halo, halo_desc.axis, nullptr));
-
   // Get workspace size
   int64_t workspace_num_elements;
   CHECK_CUDECOMP_EXIT(cudecompGetHaloWorkspaceSize(handle, m_GridConfig, halo_desc.axis, m_PencilInfo.halo_extents,
                                                    &workspace_num_elements));

+  // TODO(Wassim): Handle complex numbers
   int64_t dtype_size;
   if (halo_desc.double_precision)
     CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_DOUBLE, &dtype_size));
@@ -42,6 +38,20 @@ HRESULT HaloExchange<real_t>::get_halo_descriptor(cudecompHandle_t handle, size_

   work_size = dtype_size * workspace_num_elements;

+  static const char* cudalloc = std::getenv("JD_ALLOCATE_WITH_XLA");
+
+  if (cudalloc == nullptr) {
+
+    CHECK_CUDECOMP_EXIT(cudecompMalloc(handle, m_GridConfig, reinterpret_cast<void**>(&m_WorkSizeBuffer),
+                                       workspace_num_elements * dtype_size));
+
+    StartTraceInfo(m_Tracer) << "cudecompMalloc will allocate for halo_exchange" << std::endl;
+
+  } else {
+    m_WorkSizeBuffer = nullptr;
+    StartTraceInfo(m_Tracer) << "XLA will allocate for halo_exchange" << std::endl;
+  }
+
   return S_OK;
 }

@@ -51,21 +61,32 @@ HRESULT HaloExchange<real_t>::halo_exchange(cudecompHandle_t handle, haloDescrip
   void* data_d = buffers[0];
   void* work_d = buffers[1];

-  // desc.axis = 2;
+  void* buffer_to_use = nullptr;
+  if (m_WorkSizeBuffer != nullptr) {
+    // Buffer was allocated with cudecompMalloc and is managed by us
+    buffer_to_use = m_WorkSizeBuffer;
+  } else {
+    // Buffer was allocated by XLA and is managed by XLA
+    buffer_to_use = work_d;
+  }
+
   // Perform halo exchange along the three dimensions
   for (int i = 0; i < 3; ++i) {
     switch (desc.axis) {
     case 0:
-      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosX(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)),
-                                               m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream));
+      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosX(handle, m_GridConfig, data_d, buffer_to_use,
+                                               get_cudecomp_datatype(real_t(0)), m_PencilInfo.halo_extents,
+                                               desc.halo_periods.data(), i, stream));
       break;
     case 1:
-      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosY(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)),
-                                               m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream));
+      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosY(handle, m_GridConfig, data_d, buffer_to_use,
+                                               get_cudecomp_datatype(real_t(0)), m_PencilInfo.halo_extents,
+                                               desc.halo_periods.data(), i, stream));
       break;
     case 2:
-      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosZ(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)),
-                                               m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream));
+      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosZ(handle, m_GridConfig, data_d, buffer_to_use,
+                                               get_cudecomp_datatype(real_t(0)), m_PencilInfo.halo_extents,
+                                               desc.halo_periods.data(), i, stream));
       break;
     }
   }
@@ -73,6 +94,13 @@ HRESULT HaloExchange<real_t>::halo_exchange(cudecompHandle_t handle, haloDescrip
   return S_OK;
 };

+template <typename real_t> HRESULT HaloExchange<real_t>::cleanUp(cudecompHandle_t handle) {
+  // Destroy the memory buffer allocated with cudecompMalloc
+  // In the case of XLA allocation, this buffer is nullptr
+  if (m_WorkSizeBuffer != nullptr) { CHECK_CUDECOMP_EXIT(cudecompFree(handle, m_GridConfig, m_WorkSizeBuffer)); }
+  return S_OK;
+}
+
 template class HaloExchange<float>;
 template class HaloExchange<double>;
 } // namespace jaxdecomp
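With this patch the workspace allocator is chosen at runtime: if
JD_ALLOCATE_WITH_XLA is unset, the buffer is allocated once with
cudecompMalloc and reused; if it is set, the XLA-provided buffer (buffers[1])
is used instead. A usage sketch from the Python side (the only assumption is
that the variable must be set before the first halo descriptor is created,
since get_halo_descriptor reads it):

    import os

    # Any value works: the C++ side only tests std::getenv() != nullptr
    os.environ["JD_ALLOCATE_WITH_XLA"] = "1"

    import jaxdecomp  # imported after the environment is configured
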
From e3d4c699ac2da78888ff10ddcc6d5e563de48f2f Mon Sep 17 00:00:00 2001
From: Wassim KABALAN
Date: Fri, 5 Jul 2024 12:41:58 +0200
Subject: [PATCH 3/6] Clean up all GridDescriptor and halo_exchange memory at
 finalize

---
 include/grid_descriptor_mgr.h |  2 +-
 src/grid_descriptor_mgr.cc    | 61 +++++++++++++++++++++++++++++++----
 2 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/include/grid_descriptor_mgr.h b/include/grid_descriptor_mgr.h
index cce9b84..c4ae180 100644
--- a/include/grid_descriptor_mgr.h
+++ b/include/grid_descriptor_mgr.h
@@ -52,7 +52,7 @@ class GridDescriptorManager {
   AsyncLogger m_Tracer;

   bool isInitialized = false;
-
+  int isMPIalreadyInitialized = false;
   cudecompHandle_t m_Handle;
   std::unordered_map<fftDescriptor, std::shared_ptr<FourierExecutor<double>>, std::hash<fftDescriptor>,
                      std::equal_to<>>

diff --git a/src/grid_descriptor_mgr.cc b/src/grid_descriptor_mgr.cc
index 40c6a25..fa0c946 100644
--- a/src/grid_descriptor_mgr.cc
+++ b/src/grid_descriptor_mgr.cc
@@ -19,9 +19,8 @@ GridDescriptorManager::GridDescriptorManager() : m_Tracer("JAXDECOMP") {
   MPI_Comm mpi_comm = MPI_COMM_WORLD;

   // Check if MPI has already been initialized by the user (maybe with mpi4py)
-  int is_initialized;
-  CHECK_MPI_EXIT(MPI_Initialized(&is_initialized));
-  if (!is_initialized) { CHECK_MPI_EXIT(MPI_Init(nullptr, nullptr)); }
+  CHECK_MPI_EXIT(MPI_Initialized(&isMPIalreadyInitialized));
+  if (!isMPIalreadyInitialized) { CHECK_MPI_EXIT(MPI_Init(nullptr, nullptr)); }
   // Initialize cuDecomp
   CHECK_CUDECOMP_EXIT(cudecompInit(&m_Handle, mpi_comm));
   isInitialized = true;
@@ -144,11 +143,12 @@ HRESULT GridDescriptorManager::createTransposeExecutor(transposeDescriptor& desc
   return hr;
 }

+// TODO(Wassim): This could be cleaned up using some polymorphism
 void GridDescriptorManager::finalize() {
   if (!isInitialized) return;

   StartTraceInfo(m_Tracer) << "JaxDecomp shut down" << std::endl;
-  // Destroy grid descriptors
+  // Destroy grid descriptors for FFTs
   for (auto& descriptor : m_Descriptors64) {
     auto& executor = descriptor.second;
     // TODO(wassim) : Cleanup cudecomp resources
@@ -175,13 +175,62 @@ void GridDescriptorManager::finalize() {
     executor->clearPlans();
   }

+  // Destroy halo descriptors
+  for (auto& descriptor : m_HaloDescriptors64) {
+    auto& executor = descriptor.second;
+    // Cleanup cudecomp resources
+    // CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can
+    // be used instead of requesting XLA to allocate the memory
+    cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
+    if (CUDECOMP_RESULT_SUCCESS != err) {
+      StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
+    }
+    executor->cleanUp(m_Handle);
+  }
+
+  for (auto& descriptor : m_HaloDescriptors32) {
+    auto& executor = descriptor.second;
+    // Cleanup cudecomp resources
+    // CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can
+    // be used instead of requesting XLA to allocate the memory
+    cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
+    if (CUDECOMP_RESULT_SUCCESS != err) {
+      StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
+    }
+    executor->cleanUp(m_Handle);
+  }
+
+  // Destroy transpose descriptors
+  for (auto& descriptor : m_TransposeDescriptors64) {
+    auto& executor = descriptor.second;
+    // Cleanup cudecomp resources
+    cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
+    if (CUDECOMP_RESULT_SUCCESS != err) {
+      StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
+    }
+  }
+
+  for (auto& descriptor : m_TransposeDescriptors32) {
+    auto& executor = descriptor.second;
+    // Cleanup cudecomp resources
+    cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
+    if (CUDECOMP_RESULT_SUCCESS != err) {
+      StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
+    }
+  }
+
   // TODO(wassim) : Cleanup cudecomp resources
   // there is an issue with mpi4py calling finalize at py_exit before this
   cudecompFinalize(m_Handle);
   // Clean finish
   CHECK_CUDA_EXIT(cudaDeviceSynchronize());
-  // MPI is finalized by the mpi4py runtime (I wish it wasn't)
-  // CHECK_MPI_EXIT(MPI_Finalize());
+  // If jaxDecomp initialized MPI, finalize it.
+  // Otherwise mpi4py will finalize its own MPI WORLD communicator
+  if (!isMPIalreadyInitialized) { CHECK_MPI_EXIT(MPI_Finalize()); }

   isInitialized = false;
 }
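The ownership rule introduced here, sketched from the mpi4py side (a minimal
illustration; jaxDecomp's behavior is as described in the diff, and the
mpi4py calls are standard):

    # mpi4py initializes MPI at import time by default, so jaxDecomp's
    # constructor sees MPI_Initialized() == true, records that in
    # isMPIalreadyInitialized, and will NOT call MPI_Finalize() itself.
    from mpi4py import MPI

    assert MPI.Is_initialized()
    # If instead jaxDecomp is used without mpi4py, it calls MPI_Init()
    # in its constructor and MPI_Finalize() in finalize().
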
From 8ca8f3675f5f39b33f7adce2a2096f7213f14791 Mon Sep 17 00:00:00 2001
From: Wassim KABALAN
Date: Fri, 5 Jul 2024 12:47:09 +0200
Subject: [PATCH 4/6] Do not allocate with XLA if not selected

---
 jaxdecomp/_src/halo.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py
index 86ac046..f0f6816 100644
--- a/jaxdecomp/_src/halo.py
+++ b/jaxdecomp/_src/halo.py
@@ -1,3 +1,4 @@
+import os
 from functools import partial
 from typing import Tuple

@@ -131,6 +132,11 @@ def lowering(ctx, x: Array, halo_extents: Tuple[int, int, int],
       config, is_double, halo_extents[::-1], halo_periods[::-1], 0)

   layout = tuple(range(n - 1, -1, -1))
+  # If XLA is not the selected allocator, allocate a workspace of size 1
+  # TODO(Wassim): Eventually only CUDA should allocate; this will be removed in the future
+  if os.environ.get("JD_ALLOCATE_WITH_XLA", "0") == "0":
+    workspace_size = 1
+
   workspace = mlir.full_like_aval(
       ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte))

From 75e7032e82e435cf576ec7b393de36d8be697e0e Mon Sep 17 00:00:00 2001
From: Wassim KABALAN
Date: Fri, 19 Jul 2024 16:12:56 +0200
Subject: [PATCH 5/6] Set work_size correctly when using halo exchange

---
 src/halo.cu | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/halo.cu b/src/halo.cu
index 42fdd7c..a15082b 100644
--- a/src/halo.cu
+++ b/src/halo.cu
@@ -36,7 +36,8 @@ HRESULT HaloExchange<real_t>::get_halo_descriptor(cudecompHandle_t handle, size_
   else
     CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_FLOAT, &dtype_size));

-  work_size = dtype_size * workspace_num_elements;
+  m_WorkSize = dtype_size * workspace_num_elements;
+  work_size = m_WorkSize;

   static const char* cudalloc = std::getenv("JD_ALLOCATE_WITH_XLA");
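The lowering-side decision from PATCH 4, restated as a self-contained sketch
(the function name is illustrative; workspace_size is the byte count reported
for the halo workspace): when CUDA owns the workspace, XLA is asked for a
dummy 1-byte buffer so the real workspace is not allocated twice.

    import os

    def effective_workspace_size(workspace_size: int) -> int:
      # "0" (the default) means XLA is not the selected allocator
      if os.environ.get("JD_ALLOCATE_WITH_XLA", "0") == "0":
        return 1  # CUDA (cudecompMalloc) owns the real workspace
      return workspace_size  # XLA allocates the full workspace
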
From d59c63841474588249711f45296acf15ba63fbdb Mon Sep 17 00:00:00 2001
From: Wassim KABALAN
Date: Fri, 19 Jul 2024 16:43:21 +0200
Subject: [PATCH 6/6] Revert CUDA allocation for halo; only XLA allocation is
 used

---
 include/halo.h         |  2 --
 jaxdecomp/_src/halo.py |  5 -----
 src/halo.cu            | 43 ++++++++----------------------------------
 3 files changed, 8 insertions(+), 42 deletions(-)

diff --git a/include/halo.h b/include/halo.h
index 558819a..0e10708 100644
--- a/include/halo.h
+++ b/include/halo.h
@@ -52,8 +52,6 @@ template <typename real_t> class HaloExchange {
   cudecompPencilInfo_t m_PencilInfo;
   int64_t m_WorkSize;

-  void* m_WorkSizeBuffer;
-
   HRESULT cleanUp(cudecompHandle_t handle);
 };

diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py
index e8b1e9b..232d8c7 100644
--- a/jaxdecomp/_src/halo.py
+++ b/jaxdecomp/_src/halo.py
@@ -123,11 +123,6 @@ def lowering(ctx, x: Array, halo_extents: Tuple[int, int, int],
       config, is_double, halo_extents[::-1], halo_periods[::-1], 0)

   layout = tuple(range(n - 1, -1, -1))
-  # If XLA is not the selected allocator, allocate a workspace of size 1
-  # TODO(Wassim): Eventually only CUDA should allocate; this will be removed in the future
-  if os.environ.get("JD_ALLOCATE_WITH_XLA", "0") == "0":
-    workspace_size = 1
-
   workspace = mlir.full_like_aval(
       ctx, 0, jax.core.ShapedArray(shape=[workspace_size], dtype=np.byte))

diff --git a/src/halo.cu b/src/halo.cu
index a15082b..623ed60 100644
--- a/src/halo.cu
+++ b/src/halo.cu
@@ -39,20 +39,6 @@ HRESULT HaloExchange<real_t>::get_halo_descriptor(cudecompHandle_t handle, size_
   m_WorkSize = dtype_size * workspace_num_elements;
   work_size = m_WorkSize;

-  static const char* cudalloc = std::getenv("JD_ALLOCATE_WITH_XLA");
-
-  if (cudalloc == nullptr) {
-
-    CHECK_CUDECOMP_EXIT(cudecompMalloc(handle, m_GridConfig, reinterpret_cast<void**>(&m_WorkSizeBuffer),
-                                       workspace_num_elements * dtype_size));
-
-    StartTraceInfo(m_Tracer) << "cudecompMalloc will allocate for halo_exchange" << std::endl;
-
-  } else {
-    m_WorkSizeBuffer = nullptr;
-    StartTraceInfo(m_Tracer) << "XLA will allocate for halo_exchange" << std::endl;
-  }
-
   return S_OK;
 }

@@ -62,32 +48,20 @@ HRESULT HaloExchange<real_t>::halo_exchange(cudecompHandle_t handle, haloDescrip
   void* data_d = buffers[0];
   void* work_d = buffers[1];

-  void* buffer_to_use = nullptr;
-  if (m_WorkSizeBuffer != nullptr) {
-    // Buffer was allocated with cudecompMalloc and is managed by us
-    buffer_to_use = m_WorkSizeBuffer;
-  } else {
-    // Buffer was allocated by XLA and is managed by XLA
-    buffer_to_use = work_d;
-  }
-
   // Perform halo exchange along the three dimensions
   for (int i = 0; i < 3; ++i) {
     switch (desc.axis) {
     case 0:
-      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosX(handle, m_GridConfig, data_d, buffer_to_use,
-                                               get_cudecomp_datatype(real_t(0)), m_PencilInfo.halo_extents,
-                                               desc.halo_periods.data(), i, stream));
+      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosX(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)),
+                                               m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream));
       break;
     case 1:
-      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosY(handle, m_GridConfig, data_d, buffer_to_use,
-                                               get_cudecomp_datatype(real_t(0)), m_PencilInfo.halo_extents,
-                                               desc.halo_periods.data(), i, stream));
+      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosY(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)),
+                                               m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream));
       break;
     case 2:
-      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosZ(handle, m_GridConfig, data_d, buffer_to_use,
-                                               get_cudecomp_datatype(real_t(0)), m_PencilInfo.halo_extents,
-                                               desc.halo_periods.data(), i, stream));
+      CHECK_CUDECOMP_EXIT(cudecompUpdateHalosZ(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)),
+                                               m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream));
       break;
     }
   }

@@ -96,9 +70,8 @@ HRESULT HaloExchange<real_t>::halo_exchange(cudecompHandle_t handle, haloDescrip
 };

 template <typename real_t> HRESULT HaloExchange<real_t>::cleanUp(cudecompHandle_t handle) {
-  // Destroy the memory buffer allocated with cudecompMalloc
-  // In the case of XLA allocation, this buffer is nullptr
-  if (m_WorkSizeBuffer != nullptr) { CHECK_CUDECOMP_EXIT(cudecompFree(handle, m_GridConfig, m_WorkSizeBuffer)); }
+  // XLA is doing the allocation
+  // nothing to clean up
   return S_OK;
 }
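After PATCH 6 the model is simple again: the descriptor reports m_WorkSize,
XLA allocates a workspace of exactly that size, halo_exchange receives it as
buffers[1], and there is nothing to free at finalize. A hypothetical
end-to-end call from Python (the halo_exchange signature is inferred from the
jaxdecomp/_src/halo.py hunks above; in a real run x would be a sharded array
on a GPU mesh):

    import jax.numpy as jnp
    import jaxdecomp

    x = jnp.zeros((64, 64, 64))
    y = jaxdecomp.halo_exchange(x,
                                halo_extents=(2, 2, 2),
                                halo_periods=(True, True, True))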