From cc99f5f6bc4bca425c7a216e367fcf6989e8d7c5 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Fri, 14 Feb 2025 17:18:51 -0800 Subject: [PATCH 1/3] Simplify distributed_remapping.interpolate I spent many hours tracking down https://github.com/CliMA/ClimaCore.jl/issues/2108 and could not find the root issue. I decided to take a different approach and simplify redefine `interpolate` in terms of `interpolate!`. --- NEWS.md | 6 +- src/Remapping/distributed_remapping.jl | 122 ++++++------------------ test/Remapping/distributed_remapping.jl | 15 +-- 3 files changed, 33 insertions(+), 110 deletions(-) diff --git a/NEWS.md b/NEWS.md index fc184ca03e..d397d50fbb 100644 --- a/NEWS.md +++ b/NEWS.md @@ -5,9 +5,11 @@ main ------- - Prior to this version, `CommonSpaces` could not be created with -`ClimaComms.MPICommContext`. This is now fixed with PR +`ClimaComms.MPICommsContext`. This is now fixed with PR [2176](https://github.com/CliMA/ClimaCore.jl/pull/2176). - +- Fixed bug in distributed remapping with CUDA. Sometimes, `ClimaCore` would not + properly fill the output arrays with the correct values. This is now fixed. PR + [2169](https://github.com/CliMA/ClimaCore.jl/pull/2169) v0.14.24 ------- diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index 77a0e7eacd..9ed6a593de 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -744,29 +744,6 @@ function _reset_interpolated_values!(remapper::Remapper) fill!(remapper._interpolated_values, 0) end -""" - _collect_and_return_interpolated_values!(remapper::Remapper, - num_fields::Int) - -Perform an MPI call to aggregate the interpolated points from all the MPI processes and save -the result in the local state of the `remapper`. Only the root process will return the -interpolated data. - -`_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays. - -`num_fields` is the number of fields that have been interpolated in this batch. -""" -function _collect_and_return_interpolated_values!( - remapper::Remapper, - num_fields::Int, -) - return ClimaComms.reduce( - remapper.comms_ctx, - remapper._interpolated_values[remapper.colons..., 1:num_fields], - +, - ) -end - function _collect_interpolated_values!( dest, remapper::Remapper, @@ -774,6 +751,8 @@ function _collect_interpolated_values!( index_field_end::Int; only_one_field, ) + # NOTE: MPI barriers for #2108 + ClimaComms.barrier(remapper.comms_ctx) if only_one_field ClimaComms.reduce!( remapper.comms_ctx, @@ -781,34 +760,19 @@ function _collect_interpolated_values!( dest, +, ) - return nothing + else + num_fields = 1 + index_field_end - index_field_begin + ClimaComms.reduce!( + remapper.comms_ctx, + view(remapper._interpolated_values, remapper.colons..., 1:num_fields), + view(dest, remapper.colons..., index_field_begin:index_field_end), + +, + ) end - - num_fields = 1 + index_field_end - index_field_begin - - ClimaComms.reduce!( - remapper.comms_ctx, - view(remapper._interpolated_values, remapper.colons..., 1:num_fields), - view(dest, remapper.colons..., index_field_begin:index_field_end), - +, - ) - + ClimaComms.barrier(remapper.comms_ctx) return nothing end -""" - batched_ranges(num_fields, buffer_length) - -Partition the indices from 1 to num_fields in such a way that no range is larger than -buffer_length. -""" -function batched_ranges(num_fields, buffer_length) - return [ - (i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for - i in 0:(div((num_fields - 1), buffer_length)) - ] -end - """ interpolate(remapper::Remapper, fields) interpolate!(dest, remapper::Remapper, fields) @@ -860,58 +824,21 @@ int12 = interpolate(remapper, [field1, field2]) ``` """ function interpolate(remapper::Remapper, fields) - + ArrayType = ClimaComms.array_type(remapper.space) + FT = Spaces.undertype(remapper.space) only_one_field = fields isa Fields.Field - if only_one_field - fields = [fields] - end - - for field in fields - axes(field) == remapper.space || - error("Field is defined on a different space than remapper") - end - - isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace - index_field_begin, index_field_end = - 1, min(length(fields), remapper.buffer_length) + interpolated_values_dim..., _buffer_length = + size(remapper._interpolated_values) - # Partition the indices in such a way that nothing is larger than - # buffer_length - index_ranges = batched_ranges(length(fields), remapper.buffer_length) + allocate_extra = only_one_field ? () : (length(fields),) + dest = ArrayType(zeros(FT, interpolated_values_dim..., allocate_extra...)) - cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1) - - interpolated_values = mapreduce(cat_fn, index_ranges) do range - num_fields = length(range) - - # Reset interpolated_values. This is needed because we collect distributed results - # with a + reduction. - _reset_interpolated_values!(remapper) - # Perform the interpolations (horizontal and vertical) - _set_interpolated_values!( - remapper, - view(fields, index_field_begin:index_field_end), - ) - - if !isa_vertical_space - # For spaces with an horizontal component, reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper, num_fields) - else - # For purely vertical spaces, just move to _interpolated_values - remapper._interpolated_values .= remapper._local_interpolated_values - end - - # Finally, we have to send all the _interpolated_values to root and sum them up to - # obtain the final answer. Only the root will contain something useful. - return _collect_and_return_interpolated_values!(remapper, num_fields) - end - - # Non-root processes - isnothing(interpolated_values) && return nothing - - return only_one_field ? interpolated_values[remapper.colons..., begin] : - interpolated_values + # interpolate! has an MPI call, so it is important to return after it is + # called, not before! + interpolate!(dest, remapper, fields) + ClimaComms.iamroot(remapper.comms_ctx) || return nothing + return dest end # dest has to be allowed to be nothing because interpolation happens only on the root @@ -927,6 +854,11 @@ function interpolate!( end isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace + for field in fields + axes(field) == remapper.space || + error("Field is defined on a different space than remapper") + end + if !isnothing(dest) # !isnothing(dest) means that this is the root process, in this case, the size have # to match (ignoring the buffer_length) diff --git a/test/Remapping/distributed_remapping.jl b/test/Remapping/distributed_remapping.jl index 64f97f4423..0d3ffcd34c 100644 --- a/test/Remapping/distributed_remapping.jl +++ b/test/Remapping/distributed_remapping.jl @@ -31,14 +31,6 @@ atexit() do global_logger(prev_logger) end -@testset "Utils" begin - # batched_ranges(num_fields, buffer_length) - @test Remapping.batched_ranges(1, 1) == [1:1] - @test Remapping.batched_ranges(1, 2) == [1:1] - @test Remapping.batched_ranges(2, 2) == [1:2] - @test Remapping.batched_ranges(3, 2) == [1:2, 3:3] -end - with_mpi = context isa ClimaComms.MPICommsContext @testset "2D extruded" begin @@ -161,10 +153,7 @@ end quad = Quadratures.GLL{4}() horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10) - horztopology = Topologies.Topology2D( - ClimaComms.SingletonCommsContext(device), - horzmesh, - ) + horztopology = Topologies.Topology2D(context, horzmesh) horzspace = Spaces.SpectralElementSpace2D(horztopology, quad) hv_center_space = @@ -330,7 +319,7 @@ end quad = Quadratures.GLL{4}() horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10) horztopology = Topologies.Topology2D( - ClimaComms.SingletonCommsContext(device), + context, horzmesh, Topologies.spacefillingcurve(horzmesh), ) From a3c659ae1e6708d80771db81f5211e07798f967e Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Thu, 6 Feb 2025 11:12:28 -0800 Subject: [PATCH 2/3] Use view instead of making a copy As suggested in https://github.com/JuliaParallel/MPI.jl/issues/892 ``` remapper._interpolated_values[remapper.colons..., begin] ``` allocates a new copy, which can trip up CUDA's synchronization. --- src/Remapping/distributed_remapping.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index 9ed6a593de..b6753f8f56 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -751,12 +751,10 @@ function _collect_interpolated_values!( index_field_end::Int; only_one_field, ) - # NOTE: MPI barriers for #2108 - ClimaComms.barrier(remapper.comms_ctx) if only_one_field ClimaComms.reduce!( remapper.comms_ctx, - remapper._interpolated_values[remapper.colons..., begin], + view(remapper._interpolated_values, remapper.colons..., 1), dest, +, ) @@ -764,12 +762,15 @@ function _collect_interpolated_values!( num_fields = 1 + index_field_end - index_field_begin ClimaComms.reduce!( remapper.comms_ctx, - view(remapper._interpolated_values, remapper.colons..., 1:num_fields), + view( + remapper._interpolated_values, + remapper.colons..., + 1:num_fields, + ), view(dest, remapper.colons..., index_field_begin:index_field_end), +, ) end - ClimaComms.barrier(remapper.comms_ctx) return nothing end From 5d071524f08f5e3d7384b72105a4d77ebc3def0e Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Sat, 15 Feb 2025 07:34:47 -0800 Subject: [PATCH 3/3] Update formatter action --- .github/workflows/JuliaFormatter.yml | 35 ++++------------------------ 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/.github/workflows/JuliaFormatter.yml b/.github/workflows/JuliaFormatter.yml index 5d32a9bc8c..3fc4059351 100644 --- a/.github/workflows/JuliaFormatter.yml +++ b/.github/workflows/JuliaFormatter.yml @@ -7,34 +7,9 @@ on: jobs: format: - runs-on: ubuntu-24.04 - timeout-minutes: 30 + runs-on: ubuntu-latest steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@0.4.0 - with: - access_token: ${{ github.token }} - - - uses: actions/checkout@v4 - - - uses: dorny/paths-filter@v2.9.1 - id: filter - with: - filters: | - julia_file_change: - - added|modified: '**.jl' - - - uses: julia-actions/setup-julia@v2 - if: steps.filter.outputs.julia_file_change == 'true' - with: - version: '1.10' - - - name: Apply JuliaFormatter - if: steps.filter.outputs.julia_file_change == 'true' - run: | - julia --color=yes --project=.dev .dev/climaformat.jl --verbose . - - - name: Check formatting diff - if: steps.filter.outputs.julia_file_change == 'true' - run: | - git diff --color=always --exit-code + - uses: julia-actions/julia-format@v3 + with: + version: '1' # Set `version` to '1.0.54' if you need to use JuliaFormatter.jl v1.0.54 (default: '1') + suggestion-label: 'format-suggest' # leave this unset or empty to show suggestions for all PRs