Multi-host distribution #3
Ok... well... either it's black magic or there is something I don't understand, but in any case my mind is blown. The following:

import numpy as np
import jax
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit

# pm (pmwd), cosmo, dyn_conf, stat_conf and init_cond are set up as in the usual single-GPU pmwd workflow
mesh_shape = (2,)  # on 2 GPUs
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = maps.Mesh(devices, ('x',))

parallel_za = pjit(
    lambda x: pm.generate_za([128, 128, 128], x, cosmo, dyn_conf, stat_conf).dm,
    in_axis_resources=PartitionSpec('x', None, None),
    out_axis_resources=PartitionSpec('x', None))

with maps.mesh(mesh.devices, mesh.axis_names):
    data = parallel_za(init_cond)

appears to be all it takes to distribute across multiple devices... and it returns the correct result. I'm very puzzled by this. To perform this operation it has to do a bunch of things, like an FFT over the distributed initial_cond field (which is split in 2 across the first dimension), which I can imagine, but then it needs to compute the displacement of 2 batches of particles. I'm really not sure how particles on process 2 get to know about the density stored by process 1, unless it internally "undistributes" the data at some point, which would defeat the purpose, or it is smart enough to devise a communication strategy to retrieve the data it needs.
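In practice the XLA SPMD partitioner devises the communication automatically rather than gathering everything on one device. For what it's worth, here is a sketch (mine, with jnp.fft.fftn standing in for the pmwd call, and the lower/compile/as_text methods depending a bit on the JAX version) of how one can peek at the communication strategy it comes up with:

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit

mesh_shape = (2,)  # on 2 GPUs
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = maps.Mesh(devices, ('x',))

# 3D FFT with the input sharded along its first axis,
# standing in for the pm.generate_za call above
parallel_fft = pjit(jnp.fft.fftn,
                    in_axis_resources=PartitionSpec('x', None, None),
                    out_axis_resources=PartitionSpec('x', None, None))

x = jnp.zeros((128, 128, 128), dtype='complex64')

with maps.mesh(mesh.devices, mesh.axis_names):
    # The compiled HLO shows the collectives (all-to-all, all-gather, ...)
    # that the partitioner inserted to move data between devices.
    print(parallel_fft.lower(x).compile().as_text())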
@eelregit Making progress on this ^^ Still have to add a few things, but really not far away from being usable as part of pmwd
Ok, I've added halo exchange and cleaned up the interface, and also added gradients of these operations. You can also select which backend you want to use: MPI, NCCL, or NVSHMEM. As far as I can tell, this should be strictly superior to the cuFFTMp library, although now that I know how to do these bindings, it would be even easier to use cuFFTMp. Here is how you do a 3D FFT distributed on many GPUs with the current version of the API:

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

import jax
import jax.numpy as jnp
import jaxdecomp

# Initialise the library, and optionally select the communication backends (defaults to NCCL)
jaxdecomp.init()
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# Set up a processor mesh (its product should equal "size")
pdims = [2, 4]
global_shape = [1024, 1024, 1024]

# Initialize the local slice of an array with the expected global size
array = jax.random.normal(shape=[1024 // pdims[1],
                                 1024 // pdims[0],
                                 1024],
                          key=jax.random.PRNGKey(rank)).astype('complex64')

# Forward FFT; note that the output is transposed
karray = jaxdecomp.pfft3d(array,
                          global_shape=global_shape, pdims=pdims)

# Reverse FFT
recarray = jaxdecomp.ipfft3d(karray,
                             global_shape=global_shape, pdims=pdims)

# Add halo regions to our array
padded_array = jnp.pad(array, [(32, 32), (32, 32), (32, 32)])

# Perform a halo exchange
padded_array = jaxdecomp.halo_exchange(padded_array,
                                       halo_extents=(32, 32, 32),
                                       halo_periods=(True, True, True),
                                       pdims=pdims,
                                       global_shape=global_shape)

Compiling is unfortunately not 100% trivial because it depends a lot on the local environment of the cluster, so I haven't managed to fully automate it yet.
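For completeness, here is a hypothetical per-rank sanity check one could append to the script above, assuming ipfft3d undoes pfft3d including normalization. The whole script is launched with one MPI process per GPU, e.g. mpirun -n 8 python demo.py for pdims = [2, 4] (script name is just an example):

# Hypothetical check, reusing array, recarray and rank from the example above
err = float(jnp.max(jnp.abs(recarray - array)))
print(f"rank {rank}: max round-trip error = {err:.3e}")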
One 1024^3 FFT on 4 V100 GPUs... 0.5 ms 🤣
Annnnd 50 ms for a 2048^3 FFT on 16 V100 GPUs on 2 nodes... 🤯 (also tagging @modichirag)
Thanks! This is a lot of progress.
Wow. But doesn't 0.5 ms sound too fast, like faster than the memory bandwidth on 1 GPU?
Should we expect weak scaling here, in which case (a bit more than) 2x the 1024^3 timing?
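As a rough sanity check (the bandwidth number below is assumed, not measured): a 1024^3 complex64 array is about 8.6 GB, so with roughly 900 GB/s of HBM2 bandwidth per V100, just reading each GPU's quarter of the array once already takes a couple of milliseconds:

# Back-of-envelope estimate with assumed numbers
n_bytes = 1024**3 * 8          # complex64 -> 8 bytes per element, ~8.6e9 bytes
bytes_per_gpu = n_bytes / 4    # array split over 4 GPUs
bandwidth = 900e9              # ~V100 HBM2 bandwidth in bytes/s (assumed)
print(bytes_per_gpu / bandwidth * 1e3, "ms to read the local slab once")  # ~2.4 ms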
0.5 ms does sound really fast, but the result of the FFT seems to be correct, so... maybe? I'm not 100% sure what scaling we should expect; the cost may not scale simply with message size, as the backend might switch between different strategies. Also, an interesting note: this is using the NCCL backend; if I use the MPI backend on the same setting and hardware I get 6 s. I guess it will be very hardware and problem dependent, but that's what's nice with cuDecomp: it includes an autotuning tool that will find the best distribution strategy and backend for a given hardware and problem size (which I haven't interfaced in JAX yet, but it is there at the C++ level).
Okay, 50 ms seems to be comparable to what cuFFTMp showed in https://developer.nvidia.com/blog/multinode-multi-gpu-using-nvidia-cufftmp-ffts-at-scale/ For me, one 1024^3 FFT on 1 A100 seems to take more than 30 ms.
Oops ^^' you are right, I didn't include a block_until_ready... New timings on V100s:
This is probably a lot more reasonable.
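For reference, a minimal timing pattern with the blocking call included (reusing array, global_shape and pdims from the API example above; only the jaxdecomp call itself is from the library, the rest is standard JAX benchmarking practice):

import time

# Run once first so the timing excludes compilation
karray = jaxdecomp.pfft3d(array, global_shape=global_shape, pdims=pdims)
karray.block_until_ready()

start = time.perf_counter()
karray = jaxdecomp.pfft3d(array, global_shape=global_shape, pdims=pdims)
karray.block_until_ready()  # wait for the asynchronous dispatch to actually finish
print(f"pfft3d: {(time.perf_counter() - start) * 1e3:.2f} ms")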
Still very promising. I wonder if the difference in performance and scaling is mainly from the hardware (NVLink, NVSwitch, etc.)?
Nvidia claims that they can achieve close to perfect weak scaling with cuFFTMp in the 2^30 elements/GPU range (up to about 4k GPUs), but I know that that library leans heavily on NVSHMEM.
Yep, it should be trivial to add an op for cuFFTMp in jaxdecomp as it's already part of the NVHPC SDK, so no need to compile an external library :-) I'm traveling this week, but if I catch a little quiet time I'll add the option to use cuFFTMp. Unless you want to have a go at it @wendazhou ;-) One thing I thought about, though: reading the documentation, it looks like NVSHMEM memory needs to be allocated in a particular fashion, which is different from standard CUDA device memory. That means we can't directly use the input/output buffers allocated by XLA, and the op will need to allocate its own buffers using NVSHMEM, which will roughly double the memory needed to run the FFT. In my current implementation, I do the transform in place within the input/output buffer allocated by XLA. I also need a workspace buffer, of a size determined by cuDecomp (I think something like twice the size of the input array), also allocated by XLA.
I'm also travelling for NeurIPS this week, so I probably won't have time to look at it. For the memory, I don't think the input/output buffers need to be allocated using special functions, only the scratch memory itself. I think the main work, in addition to plumbing everything together, will be to figure out how to describe the data layout correctly (doc).
@eelregit ... ok, it took about a year, but it's now working nicely with the latest version of JAX, thanks to the heroic efforts of my collaborator @ASKabalan :-) We have updated jaxDecomp https://github.com/DifferentiableUniverseInitiative/jaxDecomp to be compatible with the native JAX distribution, and we have a rough prototype of a distributed LPT demo here: DifferentiableUniverseInitiative/JaxPM#19 (comment). We have been able to scale it to 24 GPUs (the max I can allocate at Flatiron), and to give you an idea, it executes an LPT simulation (no PM iteration) on a 2048^3 mesh in 4.7 s. We haven't yet carefully profiled the execution, so it's not impossible we could do even better, but at least it means we can do reasonable cosmological volumes in a matter of seconds. With @ASKabalan we are probably going to integrate this into JaxPM, as a way to prototype user APIs compatible with distribution (it's not completely trivial; we don't want to bother users with a lot of extra machinery if they are running on a single GPU).
Thanks @EiffL for letting me know. Let's find a time (maybe next week) to chat via email?
@eelregit here is a minimal demo of LPT implemented using jaxdecomp: https://github.com/DifferentiableUniverseInitiative/jaxDecomp/blob/main/examples/lpt_nbody_demo.py
In fantastic news, after over a year of waiting and checking every few months whether it was working yet, it looks like it's finally possible to instantiate a distributed XLA runtime in JAX, which means... native access to NCCL collectives and composable parallelisation with pmap and xmap!!!
Demo of how to allocate 16 GPUs across 4 nodes on Perlmutter here: https://github.com/EiffL/jax-gpu-cluster-demo
I'll be testing these things out and documenting my findings in this issue. It may not be directly useful at first, but at some point down the line we want to be able to run very large sims easily.
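The core of it is just initializing the distributed runtime in every process before doing anything else. Roughly (the coordinator address, process count and rank below are placeholders; on Perlmutter they would come from SLURM environment variables):

import jax

# Called once per process, before any other JAX call.
jax.distributed.initialize(
    coordinator_address="nid000001:1234",  # placeholder host:port of process 0
    num_processes=4,                       # placeholder total number of processes
    process_id=0,                          # placeholder rank of this process
)

print(jax.devices())        # all GPUs across all hosts
print(jax.local_devices())  # only the GPUs attached to this process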