Skip to content

Commit

Permalink
different approach: use Ptr to Vector and native +
Browse files Browse the repository at this point in the history
  • Loading branch information
benegee committed Sep 5, 2024
1 parent c5970f3 commit 6140e98
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ jobs:
- uses: julia-actions/julia-buildpkg@v1
env:
PYTHON: ""
- run: julia --project=@. -e 'using Pkg; pkg"add MPI#vc/custom_ops"'
- name: Run tests without coverage
uses: julia-actions/julia-runtest@v1
with:
Expand Down
8 changes: 0 additions & 8 deletions src/auxiliary/mpi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,3 @@ parallel execution of Trixi.jl.
See the "Miscellaneous" section of the [documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/).
"""
ode_unstable_check(dt, u, semi, t) = isnan(dt)

# Custom MPI operators to work around
# https://github.com/trixi-framework/Trixi.jl/issues/1922
function reduce_vector_plus(x, y)
x .+ y
end
MPI.@Op(reduce_vector_plus, SVector{4, Float64})
MPI.@Op(reduce_vector_plus, SVector{5, Float64})
15 changes: 12 additions & 3 deletions src/callbacks_step/analysis_dg2d_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,19 @@ function integrate_via_indices(func::Func, u,
normalize = normalize)

# OBS! Global results are only calculated on MPI root, all other domains receive `nothing`
global_integral = MPI.Reduce!(Ref(local_integral), reduce_vector_plus, mpi_root(),
mpi_comm())
if local_integral isa Real
global_integral = MPI.Reduce!(Ref(local_integral), +, mpi_root(), mpi_comm())
else
global_integral = MPI.Reduce!(Base.unsafe_convert(Ptr{Float64}, Ref(local_integral)), +, mpi_root(), mpi_comm())
end

if mpi_isroot()
integral = convert(typeof(local_integral), global_integral[])
if local_integral isa Real
integral = global_integral[]
else
global_wrapped = unsafe_wrap(Array, global_integral, length(local_integral))
integral = convert(typeof(local_integral), global_wrapped)
end
else
integral = convert(typeof(local_integral), NaN * local_integral)
end
Expand Down

0 comments on commit 6140e98

Please sign in to comment.