Skip to content

Commit

Permalink
Implement FD shmem
Browse files Browse the repository at this point in the history
Try dont_limit on recursive resolve_shmem methods

Fixes + more dont limit

Matrix field fixes

Matrix field fixes

DivergenceF2C fix

MatrixField fixes

Qualify DivergenceF2C

wip

Refactor + fixed space bug. All seems good.
  • Loading branch information
charleskawczynski authored and Charlie Kawczynski committed Feb 20, 2025
1 parent 3f0a9d2 commit d807e6e
Show file tree
Hide file tree
Showing 13 changed files with 1,012 additions and 32 deletions.
10 changes: 10 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,16 @@ steps:
key: unit_spec_ops_plane
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/spectralelement/plane.jl"

- label: "Unit: FD operator (shmem)"
key: unit_fd_operator_shmem
command:
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Operators/finitedifference/unit_fd_ops_shared_memory.jl"
- "julia --color=yes --project=.buildkite test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl"
env:
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_gpus: 1

- label: "Unit: column"
key: unit_column
command:
Expand Down
2 changes: 2 additions & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ include(joinpath("cuda", "operators_integral.jl"))
include(joinpath("cuda", "remapping_interpolate_array.jl"))
include(joinpath("cuda", "limiters.jl"))
include(joinpath("cuda", "operators_sem_shmem.jl"))
include(joinpath("cuda", "operators_fd_shmem_common.jl"))
include(joinpath("cuda", "operators_fd_shmem.jl"))
include(joinpath("cuda", "operators_thomas_algorithm.jl"))
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
Expand Down
30 changes: 30 additions & 0 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,33 @@ end
ij,
slabidx,
) = Operators.is_valid_index(space, ij, slabidx)

##### shmem fd kernel partition
# Build the kernel launch configuration for the shared-memory FD stencil kernels.
#
# Arguments:
#  - `us`: universal size of the data; `Nq`, `Nv` and `Nh` are read from it.
#  - `n_face_levels`: used directly as the number of threads per block.
#  - `n_max_threads`: accepted for interface compatibility; not read in this body.
#
# Returns a named tuple `(; threads, blocks, Nvthreads)` where `threads` is a
# 1-tuple and `blocks` is `(Nh, number of vertical blocks, Nq * Nq)`.
@inline function fd_stencil_partition(
    us::DataLayouts.UniversalSize,
    n_face_levels::Integer,
    n_max_threads::Integer = 256;
)
    Nq, _, _, Nv, Nh = DataLayouts.universal_size(us)
    # One thread per vertical face level.
    Nvthreads = n_face_levels
    @assert Nvthreads <= maximum_allowable_threads()[1] "Number of vertical face levels cannot exceed $(maximum_allowable_threads()[1])"
    # +1 may be needed to guarantee that shared memory is populated at the last cell face
    n_vertical_blocks = cld(Nv, Nvthreads)
    threads = (Nvthreads,)
    blocks = (Nh, n_vertical_blocks, Nq * Nq)
    return (; threads, blocks, Nvthreads)
end
# Map the current CUDA thread/block to a 5-component index `(i, j, 1, v, h)`.
#
# The block index supplies `h` (x), the vertical block number (y) and the
# linearized horizontal node index `ij` (z); the thread index supplies the
# offset within the vertical block. When `ij` exceeds `Nq * Nq`, the sentinel
# `CartesianIndex((-1, -1, 1, -1, -1))` is returned instead.
# `space` is only used for dispatch; it is not read in this body.
@inline function fd_stencil_universal_index(space::Spaces.AbstractSpace, us)
    block = CUDA.blockIdx()
    ij = block.z
    Nq = DataLayouts.universal_size(us)[1]
    # Out-of-range horizontal node: hand back the sentinel index.
    ij > Nq * Nq && return CartesianIndex((-1, -1, 1, -1, -1))
    h = block.x
    # Global vertical index from the in-block thread position and block number.
    v = CUDA.threadIdx().x + (block.y - 1) * CUDA.blockDim().x
    @inbounds (i, j) = Tuple(CartesianIndices((Nq, Nq))[ij])
    return CartesianIndex((i, j, 1, v, h))
end
# Return `true` when the fifth component of `I` lies in `1:Nh`,
# where `Nh` is read from `us` via `DataLayouts.get_Nh`.
# Only the fifth component is checked here.
@inline fd_stencil_is_valid_index(I::CI5, us::UniversalSize) =
    1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
94 changes: 94 additions & 0 deletions ext/cuda/operators_fd_shmem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
import CUDA
import ClimaCore.Operators: return_eltype, get_local_geometry

# Allocate the temporary shared-memory buffer used by `DivergenceF2C`.
#
# The buffer is a `CUDA.CuStaticSharedArray` of length `Nvt` whose element type
# is `return_eltype(op, args...)`. `space` is not read in this body.
Base.@propagate_inbounds function fd_operator_shmem(
    space,
    ::Val{Nvt},
    op::Operators.DivergenceF2C,
    args...,
) where {Nvt}
    # Element type produced by applying `op` to `args`.
    elty = return_eltype(op, args...)
    return CUDA.CuStaticSharedArray(elty, (Nvt,))
end

# Fill this thread's slot of `Ju³` for an interior face of `DivergenceF2C`.
#
# Reads the argument value at (`loc`, `idx`, `hidx`) with `Operators.getidx`,
# converts it with `Geometry.Jcontravariant3` using the local geometry at
# (`idx`, `hidx`), and stores the result at `Ju³[threadIdx().x]`.
# Returns `nothing`.
Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
    op::Operators.DivergenceF2C,
    Ju³,
    loc, # can be any location
    space,
    idx::Utilities.PlusHalf,
    hidx,
    arg,
)
    @inbounds begin
        vt = CUDA.threadIdx().x
        lg = Geometry.LocalGeometry(space, idx, hidx)
        u³ = Operators.getidx(space, arg, loc, idx, hidx)
        Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
    end
    return nothing
end

# Fill this thread's slot of `Ju³` at the left boundary face for a `SetValue`
# boundary condition on `DivergenceF2C`.
#
# Errors with "Incorrect left idx" unless `idx` equals
# `Operators.left_face_boundary_idx(space)`. The stored value is built from the
# boundary value `bc.val` (fetched with a `nothing` vertical index), not from
# `arg`, which is unused in this body. Returns `nothing`.
Base.@propagate_inbounds function fd_operator_fill_shmem_left_boundary!(
    op::Operators.DivergenceF2C,
    bc::Operators.SetValue,
    Ju³,
    loc,
    space,
    idx::Utilities.PlusHalf,
    hidx,
    arg,
)
    idx == Operators.left_face_boundary_idx(space) ||
        error("Incorrect left idx")
    @inbounds begin
        vt = CUDA.threadIdx().x
        lg = Geometry.LocalGeometry(space, idx, hidx)
        u³ = Operators.getidx(space, bc.val, loc, nothing, hidx)
        Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
    end
    return nothing
end

# Fill this thread's slot of `Ju³` at the right boundary face for a `SetValue`
# boundary condition on `DivergenceF2C`.
#
# Errors with "Incorrect right idx" unless `idx` equals
# `Operators.right_face_boundary_idx(space)`. The stored value is built from the
# boundary value `bc.val` (fetched with a `nothing` vertical index), not from
# `arg`, which is unused in this body. Returns `nothing`.
Base.@propagate_inbounds function fd_operator_fill_shmem_right_boundary!(
    op::Operators.DivergenceF2C,
    bc::Operators.SetValue,
    Ju³,
    loc,
    space,
    idx::Utilities.PlusHalf,
    hidx,
    arg,
)
    # The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
    # NOTE(review): the store below targets `Ju³[vt]`, not `vt + 1`, and `idx` is
    # used unmodified — confirm the comment above still matches the caller.
    idx == Operators.right_face_boundary_idx(space) ||
        error("Incorrect right idx")
    @inbounds begin
        vt = CUDA.threadIdx().x
        lg = Geometry.LocalGeometry(space, idx, hidx)
        u³ = Operators.getidx(space, bc.val, loc, nothing, hidx)
        Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
    end
    return nothing
end

# Evaluate `DivergenceF2C` at center index `idx` from the pre-filled buffer `Ju³`.
#
# Takes the two buffer entries at `vt` and `vt + 1` (with `vt = threadIdx().x`),
# subtracts the lower from the upper and scales the difference by the local
# geometry's `invJ` at (`idx`, `hidx`). `loc` and `args` are not read here.
Base.@propagate_inbounds function fd_operator_evaluate(
    op::Operators.DivergenceF2C,
    Ju³,
    loc,
    space,
    idx::Integer,
    hidx,
    args...,
)
    @inbounds begin
        vt = CUDA.threadIdx().x
        local_geometry = Geometry.LocalGeometry(space, idx, hidx)
        Ju³₋ = Ju³[vt] # corresponds to idx - half
        Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
        return (Ju³₊ ⊟ Ju³₋) ⊠ local_geometry.invJ
    end
end
Loading

0 comments on commit d807e6e

Please sign in to comment.