update paper
ASKabalan committed Jul 22, 2024
1 parent bb8f6b6 commit 1396d48
Showing 3 changed files with 51 additions and 67 deletions.
Binary file added joss-paper/assets/strong_scaling.png
Binary file added joss-paper/assets/weak_scaling.png
118 changes: 51 additions & 67 deletions joss-paper/paper.md
@@ -27,8 +27,6 @@ affiliations:
index: 3
date: 26 June 2024
bibliography: paper.bib

---

@@ -41,7 +39,7 @@ JAX [@JAX] has seen widespread adoption in both machine learning and scientific

Recently, JAX has made a major push towards simplified SPMD programming, with the unification of the JAX array API and the introduction of several powerful APIs, such as `pjit`, `shard_map`, and `custom_partitioning`. However, not all native JAX operations have specialized distribution strategies, and `pjitting` a program can lead to excessive communication overhead for some operations, particularly the 3D Fast Fourier Transform (FFT), one of the most critical and widely used algorithms in scientific computing. Distributed FFTs are essential for many simulations and solvers, especially in fields like cosmology and fluid dynamics, where large-scale data processing is required.

To address these limitations, we introduce jaxDecomp, a JAX library that wraps NVIDIA's cuDecomp domain decomposition library [@cuDecomp]. jaxDecomp provides JAX primitives with highly efficient CUDA implementations for key operations such as 3D FFTs and halo exchanges. By integrating seamlessly with JAX, jaxDecomp supports running on multiple GPUs and nodes, enabling large-scale, distributed scientific computations. Implemented as JAX primitives, jaxDecomp builds directly on top of the distributed Array strategy in JAX and is compatible with JAX transformations such as `jax.grad` and `jax.jit`, ensuring fast execution and differentiability with a Pythonic, easy-to-use interface. Using cuDecomp, jaxDecomp can switch between `NCCL`, `CUDA-Aware MPI`, and `NVSHMEM` for distributed array transpose operations, allowing it to best fit the specific HPC cluster configuration.

# Statement of Need

@@ -51,83 +49,65 @@ In scientific applications such as particle mesh (PM) simulations for cosmology,

While it is technically feasible to implement distributed FFTs in native JAX, jaxDecomp offers several advantages that make it a valuable tool for HPC applications, even where the raw performance difference is marginal. First, jaxDecomp can easily switch communication backends between NCCL, MPI, and NVSHMEM, optimizing performance for the specific HPC cluster configuration (see the sketch below). Second, jaxDecomp performs operations in place, minimizing the use of intermediate memory; this memory efficiency is crucial for memory-bound codes such as cosmological simulations.
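For illustration, a backend switch is a one-line configuration change. The sketch below follows the configuration API shown in the project README; the exact option and constant names (`transpose_comm_backend`, `jaxdecomp.TRANSPOSE_COMM_NCCL`, ...) are assumptions of this sketch, mirroring cuDecomp's backend options.

```python
import jaxdecomp

# Assumed configuration API (see the jaxDecomp README); the constants
# mirror cuDecomp's communication backend options.
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
```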

# Implementation

## Distributed FFT Algorithm

The following steps outline the distributed FFT algorithm in jaxDecomp, which uses 2D domain decomposition to distribute 3D data across GPUs.

| Steps            | Local Operation                           | Global Operation                                                                                 |
|------------------|-------------------------------------------|--------------------------------------------------------------------------------------------------|
| FFT along X      | Perform batched FFT along the X axis.     | -                                                                                                |
| Transpose X to Y | Local transpose to $Y \times X \times Z$  | All-to-all communication to concatenate along $Y$: $Y \times \frac{X}{P_y} \times \frac{Z}{P_z}$ |
| FFT along Y      | Perform batched FFT along the Y axis.     | -                                                                                                |
| Transpose Y to Z | Local transpose to $Z \times X \times Y$  | All-to-all communication to concatenate along $Z$: $Z \times \frac{X}{P_z} \times \frac{Y}{P_y}$ |
| FFT along Z      | Perform batched FFT along the Z axis.     | -                                                                                                |

An $X \times Y \times Z$ array is distributed across a $P_y \times P_z$ grid of GPUs, so that each device initially holds a pencil of shape $X \times \frac{Y}{P_y} \times \frac{Z}{P_z}$. Each transpose step combines a local cyclic transposition of the axes with a transposition of the processor grid.
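Concretely, the initial pencils can be expressed with stock JAX sharding. The sketch below is illustrative only: the mesh shape, axis names, and array size are assumptions, and it presumes four visible GPUs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2D process grid with P_y = P_z = 2 (assumes 4 GPUs).
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('y', 'z'))

# Pencil layout of an X x Y x Z array: X stays local,
# Y is split over P_y and Z is split over P_z.
sharding = NamedSharding(mesh, P(None, 'y', 'z'))
global_array = jax.device_put(jnp.zeros((256, 256, 256)), sharding)
# Each GPU now holds a local block of shape (256, 128, 128).
```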

To expose these changes to JAX transparently, the local transpositions are described in the lowering of the primitive, while the processor-grid transpositions are described in the `infer_sharding_from_operands` rule of JAX's `custom_partitioning` API.
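A minimal sketch of this pattern, modeled on the `custom_partitioning` example in the JAX documentation rather than on jaxDecomp's actual source: for the embarrassingly parallel FFT along the already-local X axis, the per-shard lowering is a plain batched FFT, and the inferred output sharding simply reuses the operand's pencil layout.

```python
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def fft_x(x):
    return jnp.fft.fft(x, axis=0)

def infer_sharding_from_operands(mesh, arg_infos, result_infos):
    # The result keeps the operand's pencil layout; this sketch assumes
    # the FFT axis (X) is unsharded, as in the decomposition above.
    return arg_infos[0].sharding

def partition(mesh, arg_infos, result_infos):
    def lower_fn(x):
        # Per-shard computation: a purely local batched FFT on each pencil.
        # The transpose primitives (not shown) perform the all-to-alls.
        return jnp.fft.fft(x, axis=0)
    return mesh, lower_fn, arg_infos[0].sharding, (arg_infos[0].sharding,)

fft_x.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)
```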

![Visualization of the distributed FFT process in jaxDecomp](assets/fft.svg)
*Figure: Visualization of the distributed FFT process in jaxDecomp*

## Distributed Halo Exchange

The halo exchange is a crucial step in distributed programming: it transfers the data at the edges of each slice to the adjacent slices, ensuring data consistency across the boundaries of the distributed domains.

Many applications in high-performance computing (HPC) use domain decomposition to distribute the workload among processing elements. Such applications, including cosmological simulations, stencil computations, and PDE solvers, require the halo regions to be updated with data from neighboring regions. This process, often referred to as a halo update, is typically implemented using MPI (Message Passing Interface) on large machines.

Using cuDecomp, we can also change the communication backend to NCCL, MPI, or NVSHMEM.

### Halo Exchange Process

In jaxDecomp, each local slice is first padded by the halo extent; for each axis, a slice of data of size equal to the halo extent is then exchanged between neighboring subdomains, as summarized below.

| Send                                                                                                     | Receive                                                                                  |
|----------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------|
| $[ \text{Size} - 2 \times \text{Halo} \rightarrow \text{Size} - \text{Halo} ]$ is sent to the next slice | $[ 0 \rightarrow \text{Halo} ]$ is received from the previous slice                      |
| $[ \text{Halo} \rightarrow 2 \times \text{Halo} ]$ is sent to the previous slice                         | $[ \text{Size} - \text{Halo} \rightarrow \text{Size} ]$ is received from the next slice  |
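In toy form, the same bookkeeping along one axis looks as follows; the three-neighbour setup and array names are illustrative, and `halo` is the halo extent.

```python
import numpy as np

def halo_update_axis0(local, prev_slice, next_slice, halo):
    """Fill the halo regions of `local` along axis 0, mirroring the table above."""
    size = local.shape[0]
    # [0, halo) receives what the previous slice sends from its [size-2*halo, size-halo).
    local[:halo] = prev_slice[size - 2 * halo:size - halo]
    # [size-halo, size) receives what the next slice sends from its [halo, 2*halo).
    local[size - halo:] = next_slice[halo:2 * halo]
    return local
```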

![Visualization of the distributed halo exchange process in jaxDecomp](assets/halo_exchange.svg)
*Figure: Visualization of the distributed halo exchange process in jaxDecomp*

### Efficient State Management in jaxDecomp

jaxDecomp manages the metadata and resources required for cuDecomp operations through a caching mechanism: the descriptors used for transpositions and halo exchanges, as well as the cuFFT plans, are created lazily (i.e., only when first needed, during JAX's just-in-time (JIT) compilation of a function) and stored for subsequent use. This ensures that resources are allocated only when necessary, reducing overhead and improving performance.
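The idea, in toy form (both functions below are illustrative stand-ins, not jaxDecomp's API):

```python
# Toy sketch of the lazy-caching idea.
_plan_cache = {}

def make_cufft_plan(global_shape, dtype, pdims):
    # Stand-in for the real, expensive cuFFT plan creation.
    return ('plan', global_shape, dtype, pdims)

def get_fft_plan(global_shape, dtype, pdims):
    # Created on first use, keyed by the problem configuration,
    # then reused by every subsequent call.
    key = (global_shape, dtype, pdims)
    if key not in _plan_cache:
        _plan_cache[key] = make_cufft_plan(global_shape, dtype, pdims)
    return _plan_cache[key]
```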


The cached data is properly destroyed at the end of the session, ensuring that no resources are wasted or leaked.

Additionally, jaxDecomp opportunistically creates inverse FFT (IFFT) plans when the forward FFT is JIT compiled. Because the IFFT plan is then readily available, this yields a 5x speedup when the IFFT itself is JIT compiled.


# API description
@@ -211,17 +191,21 @@ def potential(delta):


A more detailed example of an LPT simulation can be found in [jaxdecomp_lpt](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/blob/main/examples/lpt_nbody_demo.py).
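As a minimal taste of the API, a distributed forward/inverse FFT round trip might look like the sketch below. The entry points `jaxdecomp.fft.pfft3d` and `jaxdecomp.fft.pifft3d` follow the project README; the mesh setup and input layout are assumptions carried over from the Implementation section, and `smooth` is a hypothetical user function.

```python
import jax
import jax.numpy as jnp
import jaxdecomp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2x2 process grid (assumes 4 GPUs).
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('y', 'z'))
sharding = NamedSharding(mesh, P(None, 'y', 'z'))

field = jax.device_put(jnp.zeros((256, 256, 256)), sharding)

@jax.jit
def smooth(x):
    k_field = jaxdecomp.fft.pfft3d(x)       # distributed forward FFT
    # ... multiply k_field by a Fourier-space kernel here ...
    return jaxdecomp.fft.pifft3d(k_field)   # distributed inverse FFT

out = smooth(field)
```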


# Benchmarks

The performance benchmarks for jaxDecomp were conducted on the Jean Zay supercomputer using A100 GPUs (each with 80 GB of memory). These tests demonstrate the scalability and efficiency of jaxDecomp for large-scale FFT operations: performance remains efficient when the computation is distributed across multiple nodes, and it continues to scale as both the number of GPUs and the problem size increase.

![Strong Scaling Performance of jaxDecomp](assets/strong_scaling.png)

![Weak Scaling Performance of jaxDecomp](assets/weak_scaling.png)


# Stability and releases

