
Multi-host distribution #3

Open
EiffL opened this issue Jan 21, 2022 · 18 comments

EiffL commented Jan 21, 2022

In fantastic news, after over a year of waiting and checking every few months whether it was working yet, it looks like it's finally possible to instantiate a distributed XLA runtime in JAX, which means.... native access to NCCL collectives and composable parallelisation with pmap and xmap!!!

Demo of how to allocate 16 GPUs across 4 nodes on Perlmutter here: https://github.com/EiffL/jax-gpu-cluster-demo

I'll be testing these things out and documenting my findings in this issue. It may not be directly useful at first, but at some point down the line we want to be able to run very large sims easily.
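For the record, the multi-node setup boils down to something like this in each process (a minimal sketch assuming a SLURM job with one process per node; the demo repo above has the full working version, and the coordinator address here is hypothetical):

import os
import jax

# One process per node; point every process at the same coordinator
jax.distributed.initialize(
    coordinator_address='nid000001:8888',  # hypothetical head-node address
    num_processes=4,
    process_id=int(os.environ['SLURM_PROCID']),
)

print(jax.devices())        # all 16 GPUs across the 4 nodes
print(jax.local_devices())  # only the 4 GPUs attached to this process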


EiffL commented Jan 22, 2022

ok.... well.... either it's black magic, or there is something I don't understand, but in any case my mind is blown....

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec

mesh_shape = (2,)  # On 2 GPUs
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = maps.Mesh(devices, ('x',))

parallel_za = pjit(
  lambda x: pm.generate_za([128, 128, 128], x, cosmo, dyn_conf, stat_conf).dm,
  in_axis_resources=PartitionSpec('x', None, None),
  out_axis_resources=PartitionSpec('x', None))

with maps.mesh(mesh.devices, mesh.axis_names):
  data = parallel_za(init_cond)

appears to be all it takes to distribute across multiple devices.... and it returns the correct result...

I'm very puzzled by this.... In order to perform this operation it has to do a bunch of things, like performing an FFT over the distributed init_cond field (which is split in 2 across the first dimension), which I can imagine, but then it needs to compute the displacement of 2 batches of particles. I'm really not sure how particles on process 2 get to know about the density stored by process 1, unless... it internally "undistributes" the data at some point, which would defeat the purpose.... or it is smart enough to devise a communication strategy to retrieve the needed data....
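One way to check what XLA actually decided to do (a sketch; the AOT lowering API depends on the JAX version) is to dump the compiled HLO and count the collectives it inserted:

# Sketch: inspect the compiled program for the communication ops XLA chose
with maps.mesh(mesh.devices, mesh.axis_names):
    hlo = parallel_za.lower(init_cond).compile().as_text()

for op in ('all-to-all', 'all-gather', 'all-reduce', 'collective-permute'):
    print(op, hlo.count(op))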


EiffL commented Nov 21, 2022

@eelregit Making progress on this ^^
jaxdecomp is now able to do forward and backward FFTs https://github.com/DifferentiableUniverseInitiative/jaxDecomp

Still have to add a few things, but it's really not far from being usable as part of pmwd.


EiffL commented Nov 25, 2022

Ok, I've added halo exchange and cleaned up the interface. Also added gradients of these operations. You can also select which backend you want to use: MPI, NCCL, or NVSHMEM. As far as I can tell, this should be strictly superior to the cuFFTMp library, although now that I know how to do these bindings, it would be even easier to use cuFFTMp.

Here is how you do a 3D FFT distributed across many GPUs with the current version of the API:

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

import jax
import jax.numpy as jnp
import jaxdecomp

# Initialise the library, and optionally select a communication backend (defaults to NCCL)
jaxdecomp.init()
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# Set up a processor mesh (pdims[0] * pdims[1] should equal "size")
pdims = [2, 4]
global_shape = [1024, 1024, 1024]

# Initialize the local slice of an array with the expected global size
array = jax.random.normal(shape=[global_shape[0] // pdims[1],
                                 global_shape[1] // pdims[0],
                                 global_shape[2]],
                          key=jax.random.PRNGKey(rank)).astype('complex64')

# Forward FFT; note that the output is transposed
karray = jaxdecomp.pfft3d(array,
                          global_shape=global_shape, pdims=pdims)

# Reverse FFT
recarray = jaxdecomp.ipfft3d(karray,
                             global_shape=global_shape, pdims=pdims)

# Add halo regions to our array
padded_array = jnp.pad(array, [(32, 32), (32, 32), (32, 32)])
# Perform a halo exchange
padded_array = jaxdecomp.halo_exchange(padded_array,
                                       halo_extents=(32, 32, 32),
                                       halo_periods=(True, True, True),
                                       pdims=pdims,
                                       global_shape=global_shape)

Compiling is unfortunately not 100% trivial because it depends a lot on the local environment of the cluster, so I haven't managed to fully automate it yet....


EiffL commented Nov 25, 2022

One 1024^3 FFT on 4 V100 GPUs... 0.5 ms 🤣


EiffL commented Nov 25, 2022

Annnnd 50ms for a 2048^3 FFT on 16 V100 GPUs on 2 nodes... 🤯 (also tagging @modichirag )


eelregit commented Nov 25, 2022

Thanks! This is a lot of progress.

> One 1024^3 FFT on 4 V100 GPUs... 0.5 ms rofl

Wow. But doesn't 0.5 ms sound too fast, like faster than the memory bandwidth on 1 GPU?

> Annnnd 50ms for a 2048^3 FFT on 16 V100 GPUs on 2 nodes

Should we expect weak scaling here, in which case (a bit more than) 2x the 1024^3 timing?
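A quick back-of-the-envelope (round numbers assumed) for why 0.5 ms looks too fast:

# Rough sanity check with assumed round numbers
n_bytes = 1024**3 * 8       # 1024^3 complex64 values, ~8.6 GB
bandwidth = 900e9           # ~900 GB/s HBM2 per V100 (assumed)
n_gpus = 4
# Lower bound: one read plus one write pass over the data, perfectly
# split across GPUs, ignoring all FFT math and all communication
t_min = 2 * n_bytes / (bandwidth * n_gpus)
print(f'{t_min * 1e3:.1f} ms')  # ~4.8 ms, already ~10x the quoted 0.5 ms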


EiffL commented Nov 25, 2022

0.5ms does sound really fast, but the result of the FFT seems to be correct, so.... maybe?

I'm not 100% sure what scaling we should expect; the cost may not vary uniformly with message size, as the backend might switch between different strategies.

Also, interesting note: this is using the NCCL backend. If I use the MPI backend on the same setting and hardware, I get 6s. I guess it will be very hardware and problem dependent, but that's what's nice with cuDecomp: it includes an autotuning tool that will find the best distribution strategy and backend for a given hardware and problem size (which I haven't interfaced in JAX yet, but it's there at the C++ level).
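For reference, switching backends is just the two config calls from the example above; a sketch (the NCCL constant names here are assumed by analogy with the MPI ones, so check the jaxDecomp docs):

import jaxdecomp

# NCCL backend (constant names assumed by analogy with the MPI ones)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_NCCL)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_NCCL)

# MPI backend, as used for the 6s timing
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)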


eelregit commented Nov 25, 2022

Okay, 50ms seems to be comparable to what cuFFTMp showed in https://developer.nvidia.com/blog/multinode-multi-gpu-using-nvidia-cufftmp-ffts-at-scale/

For me one 1024^3 FFT on 1 A100 seems to take more than 30ms.
So... are you sure that 0.5ms on 4 V100s is not an underestimate?


EiffL commented Nov 25, 2022

oops ^^' you are right, I didn't include a block_until_ready...

New timings on V100s:

  • 2048^3 on 16 GPUs using NCCL: 8s
  • 1024^3 on 4 GPUs using NCCL: 0.2s

This is probably a lot more reasonable
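For reference, the corrected timing pattern is a minimal sketch like this, reusing the names from the FFT example above:

import time

# Warm-up call: triggers compilation, excluded from the timing
karray = jaxdecomp.pfft3d(array, global_shape=global_shape, pdims=pdims)
karray.block_until_ready()

t0 = time.perf_counter()
karray = jaxdecomp.pfft3d(array, global_shape=global_shape, pdims=pdims)
karray.block_until_ready()  # JAX dispatches asynchronously; wait for the result
print(f'FFT: {time.perf_counter() - t0:.3f} s')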

@eelregit

Still very promising. I wonder if the difference in performance and scaling is mainly from the hardware (NVLink, NVSwitch, etc.)?

@wendazhou

NVIDIA claims close to perfect weak scaling with cuFFTMp in the 2^30 elements / GPU range (up to about 4k GPUs), but I know that library leans heavily on NVSHMEM communication to achieve optimal overlapping. NVIDIA claims their cluster can do a 2048^3 FFT in ~100ms on 16 GPUs (albeit A100s), so it definitely might be worth looking into using that library directly / setting up the correct hardware config for cuFFTMp and cuDecomp.


EiffL commented Nov 28, 2022

Yep, it should be trivial to add an op for cuFFTMp in jaxDecomp as it's already part of the NVHPC SDK, so no need to compile an external library :-) I'm traveling this week, but if I catch a little quiet time I'll add the option to use cuFFTMp. Unless you want to have a go at it @wendazhou ;-)

One thing I thought about though: reading the documentation, it looks like NVSHMEM memory needs to be allocated in a particular fashion, different from standard CUDA device memory. That means we can't directly use the input/output buffers allocated by XLA, and the op will need to allocate its own buffers using NVSHMEM, which will roughly double the memory needed to run the FFT.

In my current implementation, I do the transform in place within the input/output buffer allocated by XLA. I also need a workspace buffer of a size determined by cuDecomp (I think something like twice the size of the input array), also allocated by XLA.

@wendazhou

I'm also travelling for NeurIPS this week, so I probably won't have time to look at it. As for memory: I don't think the input/output buffers need to be allocated using special functions, only the scratch memory itself, which cuFFTMp handles internally (but this indeed requires memory linear in the transform size, see the docs).

I think the main work, in addition to plumbing everything together, will be figuring out how to describe the data layout correctly (see the docs).


eelregit commented Feb 3, 2023

@hai4john


EiffL commented May 15, 2024

@eelregit ... ok, it took about a year, but it's now working nicely with the latest version of JAX thanks to the heroic efforts of my collaborator @ASKabalan :-)

We have updated jaxDecomp https://github.com/DifferentiableUniverseInitiative/jaxDecomp to be compatible with the native JAX distribution, and we have a rough prototype of a distributed LPT demo here: DifferentiableUniverseInitiative/JaxPM#19 (comment)

We have been able to scale it to 24 GPUs (the max I can allocate at Flatiron). To give you an idea, it executes an LPT simulation (no PM iterations) on a 2048^3 mesh in 4.7s. We haven't yet carefully profiled the execution, so it's not impossible we could do even better, but at least it means we can do reasonable cosmological volumes in a matter of seconds.

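For context, a rough sketch of what the native-distribution setup looks like with the current JAX sharding API (the (4, 6) process layout is hypothetical; the linked demo is the authoritative version):

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.distributed.initialize()  # one process per GPU, e.g. launched with srun

# 2D process mesh; (4, 6) is a hypothetical layout for 24 GPUs
devices = mesh_utils.create_device_mesh((4, 6))
mesh = Mesh(devices, axis_names=('x', 'y'))

# Fields live as globally sharded jax.Arrays, one slab per process
sharding = NamedSharding(mesh, P('x', 'y'))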

With @ASKabalan we are probably going to integrate this in JaxPM, as a way to prototype user APIs compatible with distribution (it's not completely trivial; we don't want to bother users with a lot of extra machinery if they are running on a single GPU).
But I still have in mind to push distribution to pmwd, so I wanted to check in with you to see what you think, and whether you already have another path to distribution in mind.

@eelregit

Thanks @EiffL for letting me know. Let's find a time (maybe next week) to chat via email?


EiffL commented Jul 18, 2024

@eelregit here is a minimal demo of LPT implemented using jaxdecomp: https://github.com/DifferentiableUniverseInitiative/jaxDecomp/blob/main/examples/lpt_nbody_demo.py


EiffL commented Jul 18, 2024

@eelregit timings https://flanusse.net/talks/Split2024/#/15/0/1
