Support MPI #752

Open

mofeing wants to merge 47 commits into main from ss/mpi

Conversation

@mofeing (Collaborator) commented Feb 15, 2025

This PR...

  • Registers the MPI routine symbol addresses when MPI.jl gets loaded
  • Specializes MPI.jl methods so they can be traced by Reactant (a rough sketch of both pieces is below)
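
Purely as illustration (not the actual code in this PR), the two bullets could look roughly like this on the Julia side; register_symbol! and emit_mpi_send! are hypothetical placeholders for whatever Reactant uses internally, and MPI.API.libmpi as the library path is also an assumption:

using Libdl, MPI

# (1) Look up the address of an MPI routine from the libmpi that MPI.jl loaded,
#     so lowered custom calls resolve against the same library.
#     (MPI.API.libmpi as the library path is an assumption.)
const libmpi_handle = Libdl.dlopen(MPI.API.libmpi)
const mpi_send_ptr  = Libdl.dlsym(libmpi_handle, :MPI_Send)
# register_symbol!("MPI_Send", mpi_send_ptr)   # hypothetical registration hook

# (2) Specialize an MPI.jl method so that, on traced arrays, it records an op in
#     the IR instead of calling MPI eagerly (tracing internals are made up here).
# function MPI.Send(buf::Reactant.TracedRArray, comm::MPI.Comm; dest, tag=0)
#     emit_mpi_send!(buf, dest, tag)            # hypothetical IR emitter
#     return nothing
# end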

unresolved questions

  • how can we represent MPI_Request with tensor and stablehlo types?
  • mmm, stablehlo.custom_call has a backend attribute that could be useful during lowering; e.g. if we want to lower to NCCL instead of MPI (both have similar APIs), we could add our own custom C functions that call NCCL but adapt them to an MPI-like API
  • @wsmoses can we create @cfunctions in Julia and pass them to the symbol table? some MPI routines might need a bit of adaptation, and writing the adaptors in Julia would be easier and faster (and would also use the correct symbols from the libmpi loaded by MPI.jl); see the sketch after this list
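
As a point of reference for the @cfunction question, an adaptor written in Julia could look roughly like this; the signature, the names, and the idea of handing the resulting pointer to the symbol table are illustrative assumptions rather than what this PR does:

using MPI

# Hypothetical adaptor with a C-compatible signature (illustrative only).
function mpi_send_adaptor(buf::Ptr{Cvoid}, count::Cint, dest::Cint, tag::Cint)::Cint
    data = unsafe_wrap(Array, Ptr{UInt8}(buf), count)   # reinterpret the raw buffer
    MPI.Send(data, MPI.COMM_WORLD; dest=dest, tag=tag)  # uses MPI.jl's own libmpi
    return Cint(0)
end

# C-callable function pointer that could then be passed to the symbol table.
const mpi_send_adaptor_ptr =
    @cfunction(mpi_send_adaptor, Cint, (Ptr{Cvoid}, Cint, Cint, Cint))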

tested

to do

  • MPI communicators
  • sharding
  • more MPI routines
  • custom reduction operators

cc @JBlaschke @hhkit

@wsmoses (Member) commented Feb 15, 2025

you won't; instead you'll emit something like


function send_wrap(%arg : memref<axb>) {
    mpi.send %arg
}

function main() {
    ...
    enzymexla.jit_call @send_wrap(%x : tensor<...>)
}

And then the lower-jit pass will convert it into a custom call. However, you will need to define a lowering of mpi.send into a corresponding MPI_Send call [which will use the symbol you just registered here]

Re CUDA, though, we also need to ensure we are synced with respect to the current CUDA stream, which you can get via enzymexla.get_stream

@mofeing (Collaborator, Author) commented Feb 16, 2025

mmm, from our last discussion on this a couple of weeks ago, I understood that we would emit this

function main() {
    ...
    mpi.send(%arg0, ...)
    ...
}

and it would get lowered to

function send_wrap(%arg : memref<axb>) {
    llvm.call <0xffff> (%arg)
}

function main() {
    ...
    enzymexla.jit_call @send_wrap(%x : tensor<...>)
    ...
}

which will finally lower to the following with the enzymexla.jit pass

function main() {
    ...
    stablehlo.custom_call @mpi_send_wrap(%x : tensor<...>)
    ...
}

is this correct or do we need to emit the enzymexla.jit_call directly from Reactant?

ahh or do you mean that any wrapping we need to do around MPI should be done in this way?

Re CUDA, though, we also need to ensure we are synced with respect to the current CUDA stream, which you can get via enzymexla.get_stream

okay, this will probably be required for NCCL

@mofeing force-pushed the ss/mpi branch 2 times, most recently from 2744d58 to 19c0eca on March 16, 2025 07:33
mofeing and others added 4 commits March 16, 2025 09:08
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@mofeing requested a review from wsmoses March 16, 2025 09:06
@mofeing marked this pull request as ready for review March 16, 2025 09:06
@mofeing (Collaborator, Author) commented Mar 16, 2025

The PR is ready for review. The missing MPI routines are waiting on other PRs or need some fixes, but they can be added later.

@giordano the MPI testset results get printed multiple times because the tests run on all ranks, but I guess that's not a problem?

@giordano (Member) left a comment

the MPI testset results get printed multiple times because the tests run on all ranks, but I guess that's not a problem?

That's unfortunate. Just don't use @testset? We don't in https://github.com/JuliaParallel/MPI.jl/tree/5ef7fef6d6c3e2ab2ad380f346c77235f47213bf/test
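
For context, the pattern in the linked MPI.jl tests is roughly: the driver launches a per-rank script under mpiexec, and the script uses bare @test assertions instead of @testset, so any failure throws, the worker exits non-zero, and the outer run(...) call fails. A minimal sketch of such a per-rank script (file name and contents are illustrative, not what this PR ships):

# integration/mpi.jl -- illustrative per-rank script
using Test, MPI

MPI.Init()
comm = MPI.COMM_WORLD

# Bare @test with no enclosing @testset: a failure throws, the worker exits
# non-zero, and that makes the driver's run(`mpiexec ...`) call error out.
@test MPI.Comm_size(comm) >= 2
@test 0 <= MPI.Comm_rank(comm) < MPI.Comm_size(comm)

MPI.Finalize()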

@safetestset "MPI" begin
using MPI
nranks = 2
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
Suggested change
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) --startup-file=no $(joinpath(@__DIR__, "integration", "mpi.jl"))`)
