Skip to content

Commit

Permalink
Merge pull request #2169 from CliMA/gb/distributed
Browse files Browse the repository at this point in the history
Fix distributed remapping bug
  • Loading branch information
Sbozzolo authored Feb 15, 2025
2 parents 3006168 + 5d07152 commit 6838f81
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 141 deletions.
35 changes: 5 additions & 30 deletions .github/workflows/JuliaFormatter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,9 @@ on:

jobs:
format:
runs-on: ubuntu-24.04
timeout-minutes: 30
runs-on: ubuntu-latest
steps:
- name: Cancel Previous Runs
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}

- uses: actions/checkout@v4

- uses: dorny/[email protected]
id: filter
with:
filters: |
julia_file_change:
- added|modified: '**.jl'
- uses: julia-actions/setup-julia@v2
if: steps.filter.outputs.julia_file_change == 'true'
with:
version: '1.10'

- name: Apply JuliaFormatter
if: steps.filter.outputs.julia_file_change == 'true'
run: |
julia --color=yes --project=.dev .dev/climaformat.jl --verbose .
- name: Check formatting diff
if: steps.filter.outputs.julia_file_change == 'true'
run: |
git diff --color=always --exit-code
- uses: julia-actions/julia-format@v3
with:
version: '1' # Set `version` to '1.0.54' if you need to use JuliaFormatter.jl v1.0.54 (default: '1')
suggestion-label: 'format-suggest' # leave this unset or empty to show suggestions for all PRs
6 changes: 4 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ main
-------

- Prior to this version, `CommonSpaces` could not be created with
`ClimaComms.MPICommContext`. This is now fixed with PR
`ClimaComms.MPICommsContext`. This is now fixed with PR
[2176](https://github.com/CliMA/ClimaCore.jl/pull/2176).

- Fixed a bug in distributed remapping with CUDA. Sometimes, `ClimaCore` would
not fill the output arrays with the correct values. This is now fixed with PR
[2169](https://github.com/CliMA/ClimaCore.jl/pull/2169).

v0.14.24
-------
Expand Down
125 changes: 29 additions & 96 deletions src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -744,29 +744,6 @@ function _reset_interpolated_values!(remapper::Remapper)
fill!(remapper._interpolated_values, 0)
end

"""
_collect_and_return_interpolated_values!(remapper::Remapper,
num_fields::Int)
Perform an MPI call to aggregate the interpolated points from all the MPI processes and save
the result in the local state of the `remapper`. Only the root process will return the
interpolated data.
`_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays.
`num_fields` is the number of fields that have been interpolated in this batch.
"""
function _collect_and_return_interpolated_values!(
remapper::Remapper,
num_fields::Int,
)
return ClimaComms.reduce(
remapper.comms_ctx,
remapper._interpolated_values[remapper.colons..., 1:num_fields],
+,
)
end

function _collect_interpolated_values!(
dest,
remapper::Remapper,
Expand All @@ -777,38 +754,26 @@ function _collect_interpolated_values!(
if only_one_field
ClimaComms.reduce!(
remapper.comms_ctx,
remapper._interpolated_values[remapper.colons..., begin],
view(remapper._interpolated_values, remapper.colons..., 1),
dest,
+,
)
return nothing
else
num_fields = 1 + index_field_end - index_field_begin
ClimaComms.reduce!(
remapper.comms_ctx,
view(
remapper._interpolated_values,
remapper.colons...,
1:num_fields,
),
view(dest, remapper.colons..., index_field_begin:index_field_end),
+,
)
end

num_fields = 1 + index_field_end - index_field_begin

ClimaComms.reduce!(
remapper.comms_ctx,
view(remapper._interpolated_values, remapper.colons..., 1:num_fields),
view(dest, remapper.colons..., index_field_begin:index_field_end),
+,
)

return nothing
end

"""
batched_ranges(num_fields, buffer_length)
Partition the indices from 1 to num_fields in such a way that no range is larger than
buffer_length.
"""
function batched_ranges(num_fields, buffer_length)
return [
(i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for
i in 0:(div((num_fields - 1), buffer_length))
]
end

"""
interpolate(remapper::Remapper, fields)
interpolate!(dest, remapper::Remapper, fields)
Expand Down Expand Up @@ -860,58 +825,21 @@ int12 = interpolate(remapper, [field1, field2])
```
"""
function interpolate(remapper::Remapper, fields)

ArrayType = ClimaComms.array_type(remapper.space)
FT = Spaces.undertype(remapper.space)
only_one_field = fields isa Fields.Field
if only_one_field
fields = [fields]
end

for field in fields
axes(field) == remapper.space ||
error("Field is defined on a different space than remapper")
end
interpolated_values_dim..., _buffer_length =
size(remapper._interpolated_values)

isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

index_field_begin, index_field_end =
1, min(length(fields), remapper.buffer_length)

# Partition the indices in such a way that nothing is larger than
# buffer_length
index_ranges = batched_ranges(length(fields), remapper.buffer_length)
allocate_extra = only_one_field ? () : (length(fields),)
dest = ArrayType(zeros(FT, interpolated_values_dim..., allocate_extra...))

cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1)

interpolated_values = mapreduce(cat_fn, index_ranges) do range
num_fields = length(range)

# Reset interpolated_values. This is needed because we collect distributed results
# with a + reduction.
_reset_interpolated_values!(remapper)
# Perform the interpolations (horizontal and vertical)
_set_interpolated_values!(
remapper,
view(fields, index_field_begin:index_field_end),
)

if !isa_vertical_space
# For spaces with a horizontal component, reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)
else
# For purely vertical spaces, just move to _interpolated_values
remapper._interpolated_values .= remapper._local_interpolated_values
end

# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer. Only the root will contain something useful.
return _collect_and_return_interpolated_values!(remapper, num_fields)
end

# Non-root processes
isnothing(interpolated_values) && return nothing

return only_one_field ? interpolated_values[remapper.colons..., begin] :
interpolated_values
# interpolate! has an MPI call, so it is important to return after it is
# called, not before!
interpolate!(dest, remapper, fields)
ClimaComms.iamroot(remapper.comms_ctx) || return nothing
return dest
end

# dest has to be allowed to be nothing because interpolation happens only on the root
Expand All @@ -927,6 +855,11 @@ function interpolate!(
end
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

for field in fields
axes(field) == remapper.space ||
error("Field is defined on a different space than remapper")
end

if !isnothing(dest)
# !isnothing(dest) means that this is the root process. In this case, the sizes have
# to match (ignoring the buffer_length)
Expand Down
15 changes: 2 additions & 13 deletions test/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,6 @@ atexit() do
global_logger(prev_logger)
end

@testset "Utils" begin
# batched_ranges(num_fields, buffer_length)
@test Remapping.batched_ranges(1, 1) == [1:1]
@test Remapping.batched_ranges(1, 2) == [1:1]
@test Remapping.batched_ranges(2, 2) == [1:2]
@test Remapping.batched_ranges(3, 2) == [1:2, 3:3]
end

with_mpi = context isa ClimaComms.MPICommsContext

@testset "2D extruded" begin
Expand Down Expand Up @@ -161,10 +153,7 @@ end

quad = Quadratures.GLL{4}()
horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10)
horztopology = Topologies.Topology2D(
ClimaComms.SingletonCommsContext(device),
horzmesh,
)
horztopology = Topologies.Topology2D(context, horzmesh)
horzspace = Spaces.SpectralElementSpace2D(horztopology, quad)

hv_center_space =
Expand Down Expand Up @@ -330,7 +319,7 @@ end
quad = Quadratures.GLL{4}()
horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10)
horztopology = Topologies.Topology2D(
ClimaComms.SingletonCommsContext(device),
context,
horzmesh,
Topologies.spacefillingcurve(horzmesh),
)
Expand Down

0 comments on commit 6838f81

Please sign in to comment.