diff --git a/docs/src/reference/collective.md b/docs/src/reference/collective.md index 50c8e3dd5..8991bb821 100644 --- a/docs/src/reference/collective.md +++ b/docs/src/reference/collective.md @@ -11,6 +11,7 @@ MPI.Ibarrier ```@docs MPI.Bcast! +MPI.Bcast MPI.bcast ``` diff --git a/src/collective.jl b/src/collective.jl index 20dce82f7..ba535ced5 100644 --- a/src/collective.jl +++ b/src/collective.jl @@ -60,6 +60,19 @@ function Bcast!(data, root::Integer, comm::Comm) Bcast!(Buffer(data), root, comm) end +""" + Bcast(obj, root::Integer, comm::Comm) + +Broadcast the `obj` from `root` to all processes in `comm`. Returns the object. +Currently `obj` must be `isbits`, i.e. `isbitstype(typeof(obj)) == true`. +""" +function Bcast(obj::T, root::Integer, comm::Comm) where T + if !isbitstype(T) + throw(ArgumentError("Bcast currently only supports `isbitstype`s.")) + end + Bcast!(Ref(obj), root, comm)[] +end + """ bcast(obj, comm::Comm; root::Integer=0) diff --git a/test/test_bcast.jl b/test/test_bcast.jl index 837722775..8c45a437c 100644 --- a/test/test_bcast.jl +++ b/test/test_bcast.jl @@ -26,6 +26,29 @@ MPI.Bcast!(B, comm; root=root) @test B == A +# Bcast: number +A = 1.23 +B = MPI.Comm_rank(comm) == root ? 1.23 : 0.0 +res = MPI.Bcast(B, root, comm) +@test typeof(res) == typeof(A) +@test res == A + +# Bcast: scalar struct +struct XY + x::Float64 + y::Float32 +end +A = XY(1.23, 4.56f0) +B = MPI.Comm_rank(comm) == root ? A : XY(0.0, 0.0f0) +res = MPI.Bcast(B, root, comm) +@test typeof(res) == typeof(A) +@test res == A + +# Bcast: array +A = rand(3) +B = MPI.Comm_rank(comm) == root ? A : zeros(3) +@test_throws ArgumentError MPI.Bcast(B, root, comm) + g = x -> x^2 + 2x - 1 if MPI.Comm_rank(comm) == root f = g