From 8d6a613ddffd2cc07666b30bb79a08eb4808ecb1 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 5 Feb 2025 17:01:17 -0500 Subject: [PATCH] Implement FD shmem Try dont_limit on recursive resolve_shmem methods Fixes + more dont limit Matrix field fixes Matrix field fixes DivergenceF2C fix MatrixField fixes Qualify DivergenceF2C wip Refactor + fixed space bug. All seems good. --- ext/ClimaCoreCUDAExt.jl | 2 + ext/cuda/data_layouts_threadblock.jl | 30 ++ ext/cuda/operators_fd_shmem.jl | 94 +++++ ext/cuda/operators_fd_shmem_common.jl | 374 ++++++++++++++++++ ext/cuda/operators_fd_shmem_is_supported.jl | 95 +++++ ext/cuda/operators_finite_difference.jl | 90 ++++- src/Operators/finitedifference.jl | 52 ++- src/Operators/spectralelement.jl | 1 + .../benchmark_fd_ops_shared_memory.jl | 71 ++++ .../finitedifference/opt_examples.jl | 2 +- .../unit_fd_ops_shared_memory.jl | 73 ++++ .../utils_fd_ops_shared_memory.jl | 146 +++++++ 12 files changed, 1017 insertions(+), 13 deletions(-) create mode 100644 ext/cuda/operators_fd_shmem.jl create mode 100644 ext/cuda/operators_fd_shmem_common.jl create mode 100644 ext/cuda/operators_fd_shmem_is_supported.jl create mode 100644 test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl create mode 100644 test/Operators/finitedifference/unit_fd_ops_shared_memory.jl create mode 100644 test/Operators/finitedifference/utils_fd_ops_shared_memory.jl diff --git a/ext/ClimaCoreCUDAExt.jl b/ext/ClimaCoreCUDAExt.jl index dbcf984105..7a4278a8e6 100644 --- a/ext/ClimaCoreCUDAExt.jl +++ b/ext/ClimaCoreCUDAExt.jl @@ -32,6 +32,8 @@ include(joinpath("cuda", "operators_integral.jl")) include(joinpath("cuda", "remapping_interpolate_array.jl")) include(joinpath("cuda", "limiters.jl")) include(joinpath("cuda", "operators_sem_shmem.jl")) +include(joinpath("cuda", "operators_fd_shmem_common.jl")) +include(joinpath("cuda", "operators_fd_shmem.jl")) include(joinpath("cuda", "operators_thomas_algorithm.jl")) include(joinpath("cuda", "matrix_fields_single_field_solve.jl")) include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl")) diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl index 43f64319af..d7f980b441 100644 --- a/ext/cuda/data_layouts_threadblock.jl +++ b/ext/cuda/data_layouts_threadblock.jl @@ -289,3 +289,33 @@ end ij, slabidx, ) = Operators.is_valid_index(space, ij, slabidx) + +##### shmem fd kernel partition +@inline function fd_stencil_partition( + us::DataLayouts.UniversalSize, + n_face_levels::Integer, + n_max_threads::Integer = 256; +) + (Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us) + Nvthreads = n_face_levels + @assert Nvthreads <= maximum_allowable_threads()[1] "Number of vertical face levels cannot exceed $(maximum_allowable_threads()[1])" + Nvblocks = cld(Nv, Nvthreads) # +1 may be needed to guarantee that shared memory is populated at the last cell face + return (; + threads = (Nvthreads,), + blocks = (Nh, Nvblocks, Nq * Nq), + Nvthreads, + ) +end +@inline function fd_stencil_universal_index(space::Spaces.AbstractSpace, us) + (tv,) = CUDA.threadIdx() + (h, bv, ij) = CUDA.blockIdx() + v = tv + (bv - 1) * CUDA.blockDim().x + (Nq, _, _, _, _) = DataLayouts.universal_size(us) + if Nq * Nq < ij + return CartesianIndex((-1, -1, 1, -1, -1)) + end + @inbounds (i, j) = CartesianIndices((Nq, Nq))[ij].I + return CartesianIndex((i, j, 1, v, h)) +end +@inline fd_stencil_is_valid_index(I::CI5, us::UniversalSize) = + 1 ≤ I[5] ≤ DataLayouts.get_Nh(us) diff --git a/ext/cuda/operators_fd_shmem.jl 
b/ext/cuda/operators_fd_shmem.jl new file mode 100644 index 0000000000..5ec82fea84 --- /dev/null +++ b/ext/cuda/operators_fd_shmem.jl @@ -0,0 +1,94 @@ +import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts +import CUDA +import ClimaCore.Operators: return_eltype, get_local_geometry + +Base.@propagate_inbounds function fd_operator_shmem( + space, + ::Val{Nvt}, + op::Operators.DivergenceF2C, + args..., +) where {Nvt} + # allocate temp output + RT = return_eltype(op, args...) + Ju³ = CUDA.CuStaticSharedArray(RT, (Nvt,)) + return Ju³ +end + +Base.@propagate_inbounds function fd_operator_fill_shmem_interior!( + op::Operators.DivergenceF2C, + Ju³, + loc, # can be any location + space, + idx::Utilities.PlusHalf, + hidx, + arg, +) + @inbounds begin + vt = threadIdx().x + lg = Geometry.LocalGeometry(space, idx, hidx) + u³ = Operators.getidx(space, arg, loc, idx, hidx) + Ju³[vt] = Geometry.Jcontravariant3(u³, lg) + end + return nothing +end + +Base.@propagate_inbounds function fd_operator_fill_shmem_left_boundary!( + op::Operators.DivergenceF2C, + bc::Operators.SetValue, + Ju³, + loc, + space, + idx::Utilities.PlusHalf, + hidx, + arg, +) + idx == Operators.left_face_boundary_idx(space) || + error("Incorrect left idx") + @inbounds begin + vt = threadIdx().x + lg = Geometry.LocalGeometry(space, idx, hidx) + u³ = Operators.getidx(space, bc.val, loc, nothing, hidx) + Ju³[vt] = Geometry.Jcontravariant3(u³, lg) + end + return nothing +end + +Base.@propagate_inbounds function fd_operator_fill_shmem_right_boundary!( + op::Operators.DivergenceF2C, + bc::Operators.SetValue, + Ju³, + loc, + space, + idx::Utilities.PlusHalf, + hidx, + arg, +) + # The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1) + idx == Operators.right_face_boundary_idx(space) || + error("Incorrect right idx") + @inbounds begin + vt = threadIdx().x + lg = Geometry.LocalGeometry(space, idx, hidx) + u³ = Operators.getidx(space, bc.val, loc, nothing, hidx) + Ju³[vt] = Geometry.Jcontravariant3(u³, lg) + end + return nothing +end + +Base.@propagate_inbounds function fd_operator_evaluate( + op::Operators.DivergenceF2C, + Ju³, + loc, + space, + idx::Integer, + hidx, + args..., +) + @inbounds begin + vt = threadIdx().x + local_geometry = Geometry.LocalGeometry(space, idx, hidx) + Ju³₋ = Ju³[vt] # corresponds to idx - half + Ju³₊ = Ju³[vt + 1] # corresponds to idx + half + return (Ju³₊ ⊟ Ju³₋) ⊠ local_geometry.invJ + end +end diff --git a/ext/cuda/operators_fd_shmem_common.jl b/ext/cuda/operators_fd_shmem_common.jl new file mode 100644 index 0000000000..62e44e29e8 --- /dev/null +++ b/ext/cuda/operators_fd_shmem_common.jl @@ -0,0 +1,374 @@ +import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts +import CUDA +import ClimaCore.Operators: return_eltype, get_local_geometry +import ClimaCore.Operators: getidx +import ClimaCore.Utilities: PlusHalf +import ClimaCore.Utilities + +Base.@propagate_inbounds function getidx( + space, + bc::StencilBroadcasted{CUDAColumnStencilStyle}, + loc::Interior, + idx, + hidx, +) + if Operators.fd_shmem_is_supported(bc) + return fd_operator_evaluate( + bc.op, + bc.work, + loc, + space, + idx, + hidx, + bc.args..., + ) + end + Operators.stencil_interior(bc.op, loc, space, idx, hidx, bc.args...) 
+end + + +Base.@propagate_inbounds function getidx( + parent_space, + bc::StencilBroadcasted{CUDAColumnStencilStyle}, + loc::Operators.LeftBoundaryWindow, + idx, + hidx, +) + space = axes(bc) + if Operators.fd_shmem_is_supported(bc) + return fd_operator_evaluate( + bc.op, + bc.work, + loc, + space, + idx, + hidx, + bc.args..., + ) + end + op = bc.op + if call_left_boundary(idx, space, bc, loc) + Operators.stencil_left_boundary( + op, + Operators.get_boundary(op, loc), + loc, + space, + idx, + hidx, + bc.args..., + ) + else + # fallback to interior stencil + Operators.stencil_interior(op, loc, space, idx, hidx, bc.args...) + end +end + +Base.@propagate_inbounds function getidx( + parent_space, + bc::StencilBroadcasted{CUDAColumnStencilStyle}, + loc::Operators.RightBoundaryWindow, + idx, + hidx, +) + space = axes(bc) + if Operators.fd_shmem_is_supported(bc) + return fd_operator_evaluate( + bc.op, + bc.work, + loc, + space, + idx, + hidx, + bc.args..., + ) + end + op = bc.op + if call_right_boundary(idx, space, bc, loc) + Operators.stencil_right_boundary( + op, + Operators.get_boundary(op, loc), + loc, + space, + idx, + hidx, + bc.args..., + ) + else + # fallback to interior stencil + Operators.stencil_interior(op, loc, space, idx, hidx, bc.args...) + end +end + +""" + fd_allocate_shmem(Val(Nvt), b) + +Create a new broadcasted object with necessary share memory allocated, +using `Nvt` nodal points per block. +""" +@inline function fd_allocate_shmem(::Val{Nvt}, obj) where {Nvt} + obj +end +@inline function fd_allocate_shmem( + ::Val{Nvt}, + bc::Broadcasted{Style}, +) where {Nvt, Style} + Broadcasted{Style}(bc.f, _fd_allocate_shmem(Val(Nvt), bc.args...), bc.axes) +end + +######### MatrixFields +# MatrixField operators are not yet supported, and we must stop recursing because +# we can have something of the form +# MatrixFields.LazyOneArgFDOperatorMatrix{DivergenceF2C{@NamedTuple{}}}(DivergenceF2C{@NamedTuple{}}(NamedTuple())) +# which `fd_shmem_is_supported` will return `true` for. + +@inline fd_allocate_shmem(_, bc::MatrixFields.LazyOperatorBroadcasted) = bc +@inline fd_allocate_shmem(_, bc::MatrixFields.FDOperatorMatrix) = bc +@inline fd_allocate_shmem(_, bc::MatrixFields.LazyOneArgFDOperatorMatrix) = bc +######### + +@inline function fd_allocate_shmem( + ::Val{Nvt}, + sbc::StencilBroadcasted{Style}, +) where {Nvt, Style} + args = _fd_allocate_shmem(Val(Nvt), sbc.args...) + work = if Operators.fd_shmem_is_supported(sbc) + fd_operator_shmem(sbc.axes, Val(Nvt), sbc.op, args...) + else + nothing + end + StencilBroadcasted{Style}(sbc.op, args, sbc.axes, work) +end + +@inline _fd_allocate_shmem(::Val{Nvt}) where {Nvt} = () +@inline _fd_allocate_shmem(::Val{Nvt}, arg, xargs...) 
where {Nvt} = ( + fd_allocate_shmem(Val(Nvt), arg), + _fd_allocate_shmem(Val(Nvt), xargs...)..., +) + +@inline function call_left_boundary(idx::T, space, bc, loc) where {T} + (; op) = bc + return Operators.has_boundary(op, loc) && + idx < Operators.left_interior_idx( + space, + op, + Operators.get_boundary(op, loc), + bc.args..., + ) +end +@inline function call_right_boundary(idx, space, bc, loc) + (; op) = bc + return Operators.has_boundary(op, loc) && + idx > Operators.right_interior_idx( + space, + bc.op, + Operators.get_boundary(bc.op, loc), + bc.args..., + ) +end + +get_arg_space(bc::StencilBroadcasted{CUDAColumnStencilStyle}, args::Tuple{}) = + axes(bc) +get_arg_space(bc::StencilBroadcasted{CUDAColumnStencilStyle}, args::Tuple) = + axes(args[1]) + +get_cent_idx(idx::Integer) = idx +get_face_idx(idx::PlusHalf) = idx +get_cent_idx(idx::PlusHalf) = idx + half +get_face_idx(idx::Integer) = idx - half + +""" + fd_resolve_shmem!( + sbc::StencilBroadcasted, + idx, + hidx, + bds + ) + +Recursively stores the arguments to all operators into shared memory, at the +given indices (if they are valid). +""" +Base.@propagate_inbounds function fd_resolve_shmem!( + sbc::StencilBroadcasted{CUDAColumnStencilStyle}, + idx, # top-level index + hidx, + bds, +) + (li, lw, rw, ri) = bds + space = axes(sbc) + + ᶜspace = Spaces.center_space(space) + ᶠspace = Spaces.face_space(space) + arg_space = get_arg_space(sbc, sbc.args) + ᶜidx = get_cent_idx(idx) + ᶠidx = get_face_idx(idx) + + _fd_resolve_shmem!(idx, hidx, bds, sbc.args...) # propagate idx, not bc_idx recursively through broadcast expressions + + # After recursion, check if shmem is supported for this operator + Operators.fd_shmem_is_supported(sbc) || return nothing + + (; op) = sbc + lloc = Operators.LeftBoundaryWindow{Spaces.left_boundary_name(space)}() + rloc = Operators.RightBoundaryWindow{Spaces.right_boundary_name(space)}() + iloc = Operators.Interior() + + IP = Topologies.isperiodic(Spaces.vertical_topology(space)) + + # Ideally, we would use something like `loc = get_location(space, bds, idx)` and dispatch on the location, + # but that is type unstable and the compiler emits very peculiar illegal memory access errors: + # `illegal memory access was encountered (code 700, ERROR_ILLEGAL_ADDRESS)` + # when using this pattern. Instead we can use dynamic branching (if-else) where loc + # inside each if-else is statically known. 
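+    # The if-else chain below therefore covers, in order: the interior fill,
+    # a left/right boundary fill when the operator has a boundary condition on
+    # that side, and an interior-style fill at the edges when it does not
+    # (e.g., periodic topologies or boundary-free operators).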
+ bc_bds = Operators.window_bounds(space, sbc) + (bc_li, bc_lw, bc_rw, bc_ri) = bc_bds + if arg_space isa Operators.AllFaceFiniteDifferenceSpace # populate shmem on faces + if IP || get_face_idx(bc_lw) ≤ ᶠidx ≤ get_face_idx(bc_rw) + 1 # interior + fd_operator_fill_shmem_interior!( + sbc.op, + sbc.work, + iloc, + space, + ᶠidx, + hidx, + sbc.args..., + ) + elseif ᶠidx < get_face_idx(bc_lw) && Operators.has_boundary(op, lloc) # left + fd_operator_fill_shmem_left_boundary!( + sbc.op, + Operators.get_boundary(op, lloc), + sbc.work, + lloc, + space, + ᶠidx, + hidx, + sbc.args..., + ) + elseif ᶠidx > get_face_idx(bc_rw) && Operators.has_boundary(op, rloc) # right + fd_operator_fill_shmem_right_boundary!( + sbc.op, + Operators.get_boundary(op, rloc), + sbc.work, + rloc, + space, + ᶠidx, + hidx, + sbc.args..., + ) + elseif ᶠidx < get_face_idx(bc_lw) && !Operators.has_boundary(op, lloc) # left + fd_operator_fill_shmem_interior!( + sbc.op, + sbc.work, + lloc, + space, + ᶠidx, + hidx, + sbc.args..., + ) + elseif ᶠidx > get_face_idx(bc_rw) && !Operators.has_boundary(op, rloc) # right + fd_operator_fill_shmem_interior!( + sbc.op, + sbc.work, + rloc, + space, + ᶠidx, + hidx, + sbc.args..., + ) + else # this else should never run + end + else # populate shmem on centers + if IP || get_cent_idx(bc_lw) ≤ ᶜidx ≤ get_cent_idx(bc_rw) + 1 # interior + fd_operator_fill_shmem_interior!( + sbc.op, + sbc.work, + iloc, + space, + ᶜidx, + hidx, + sbc.args..., + ) + elseif ᶜidx < get_cent_idx(bc_lw) && Operators.has_boundary(op, lloc) # left + fd_operator_fill_shmem_left_boundary!( + sbc.op, + Operators.get_boundary(op, lloc), + sbc.work, + lloc, + space, + ᶜidx, + hidx, + sbc.args..., + ) + elseif ᶜidx > get_cent_idx(bc_rw) && Operators.has_boundary(op, rloc) # right + fd_operator_fill_shmem_right_boundary!( + sbc.op, + Operators.get_boundary(op, rloc), + sbc.work, + rloc, + space, + ᶜidx, + hidx, + sbc.args..., + ) + elseif ᶜidx < get_cent_idx(bc_lw) && !Operators.has_boundary(op, lloc) # left + fd_operator_fill_shmem_interior!( + sbc.op, + Operators.get_boundary(op, lloc), + sbc.work, + lloc, + space, + ᶜidx, + hidx, + sbc.args..., + ) + elseif ᶜidx > get_cent_idx(bc_rw) && !Operators.has_boundary(op, rloc) # right + fd_operator_fill_shmem_interior!( + sbc.op, + Operators.get_boundary(op, rloc), + sbc.work, + rloc, + space, + ᶜidx, + hidx, + sbc.args..., + ) + else # this else should never run + end + end + return nothing +end + +Base.@propagate_inbounds _fd_resolve_shmem!(idx, hidx, bds) = nothing +Base.@propagate_inbounds function _fd_resolve_shmem!( + idx, + hidx, + bds, + arg, + xargs..., +) + fd_resolve_shmem!(arg, idx, hidx, bds) + _fd_resolve_shmem!(idx, hidx, bds, xargs...) +end + +Base.@propagate_inbounds fd_resolve_shmem!( + bc::Broadcasted{CUDAColumnStencilStyle}, + idx, + hidx, + bds, +) = _fd_resolve_shmem!(idx, hidx, bds, bc.args...) +@inline fd_resolve_shmem!(obj, idx, hidx, bds) = nothing + +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(fd_resolve_shmem!) + m.recursion_relation = dont_limit + end + for m in methods(_fd_resolve_shmem!) 
+ m.recursion_relation = dont_limit + end + for m in methods(_fd_allocate_shmem) + m.recursion_relation = dont_limit + end + for m in methods(fd_allocate_shmem) + m.recursion_relation = dont_limit + end +end diff --git a/ext/cuda/operators_fd_shmem_is_supported.jl b/ext/cuda/operators_fd_shmem_is_supported.jl new file mode 100644 index 0000000000..e7e7adbfe7 --- /dev/null +++ b/ext/cuda/operators_fd_shmem_is_supported.jl @@ -0,0 +1,95 @@ +import ClimaCore.MatrixFields +import ClimaCore.Operators: any_fd_shmem_supported + +@inline _any_fd_shmem_supported_args(falsesofar, args::Tuple, rargs...) = + falsesofar && + _any_fd_shmem_supported(falsesofar, args[1], rargs...) && + _any_fd_shmem_supported_args(falsesofar, Base.tail(args), rargs...) + +@inline _any_fd_shmem_supported_args(falsesofar, args::Tuple{Any}, rargs...) = + falsesofar && _any_fd_shmem_supported(falsesofar, args[1], rargs...) +@inline _any_fd_shmem_supported_args(falsesofar, args::Tuple{}, rargs...) = + falsesofar + +@inline function _any_fd_shmem_supported( + falsesofar, + bc::Base.Broadcast.Broadcasted, +) + return falsesofar && _any_fd_shmem_supported_args(falsesofar, bc.args) +end + +@inline _any_fd_shmem_supported(falsesofar, _, x::AbstractData) = false +@inline _any_fd_shmem_supported(falsesofar, _, x) = falsesofar + +@inline any_fd_shmem_supported(bc) = any_fd_shmem_supported(false, bc) + +@inline any_fd_shmem_supported(falsesofar, bc::StencilBroadcasted) = + falsesofar || + Operators.fd_shmem_is_supported(bc) || + _any_fd_shmem_supported_args(falsesofar, bc.args) + +@inline any_fd_shmem_supported(falsesofar, bc::Operators.Operator2Stencil) = + falsesofar || Operators.fd_shmem_is_supported(bc) + +@inline any_fd_shmem_supported(falsesofar, bc::Operators.ComposeStencils) = + falsesofar || Operators.fd_shmem_is_supported(bc) + +@inline any_fd_shmem_supported(falsesofar, bc::Operators.ApplyStencil) = + falsesofar || Operators.fd_shmem_is_supported(bc) + +@inline any_fd_shmem_supported(falsesofar, bc::Base.Broadcast.Broadcasted) = + falsesofar || _any_fd_shmem_supported_args(falsesofar, bc.args) + +@inline any_fd_shmem_supported(bc::Base.Broadcast.Broadcasted) = + _any_fd_shmem_supported_args(false, bc.args) + + +# Fallback is false: +@inline Operators.fd_shmem_is_supported(bc::StencilBroadcasted) = + Operators.fd_shmem_is_supported(bc.op) + +##### MatrixFields +@inline Operators.fd_shmem_is_supported(op::Operators.Operator2Stencil) = false + +@inline Operators.fd_shmem_is_supported( + bc::MatrixFields.LazyOperatorBroadcasted, +) = false + +@inline Operators.fd_shmem_is_supported(bc::Operators.ApplyStencil) = false + +@inline Operators.fd_shmem_is_supported(bc::Operators.ComposeStencils) = false + +@inline Operators.fd_shmem_is_supported(op::MatrixFields.FDOperatorMatrix) = + false + +@inline Operators.fd_shmem_is_supported( + op::MatrixFields.LazyOneArgFDOperatorMatrix, +) = false +##### + +@inline Operators.fd_shmem_is_supported(op::Operators.AbstractOperator) = + Operators.fd_shmem_is_supported(op, op.bcs) + +@inline Operators.fd_shmem_is_supported( + op::MatrixFields.MultiplyColumnwiseBandMatrixField, +) = false + +@inline Operators.fd_shmem_is_supported( + op::Operators.AbstractOperator, + bcs::NamedTuple, +) = false + +# Add cases here where shmem is supported: +@inline Operators.fd_shmem_is_supported(op::Operators.DivergenceF2C) = + Operators.fd_shmem_is_supported(op, op.bcs) +@inline Operators.fd_shmem_is_supported( + op::Operators.DivergenceF2C, + ::@NamedTuple{}, +) = true +@inline Operators.fd_shmem_is_supported( + 
op::Operators.DivergenceF2C, + bcs::NamedTuple, +) = + all(values(bcs)) do bc + all(supported_bc -> bc isa supported_bc, (Operators.SetValue,)) + end diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl index 870de35083..6d7c717926 100644 --- a/ext/cuda/operators_finite_difference.jl +++ b/ext/cuda/operators_finite_difference.jl @@ -2,6 +2,8 @@ import ClimaCore: Spaces, Quadratures, Topologies import Base.Broadcast: Broadcasted import ClimaComms using CUDA: @cuda +import ClimaCore.Utilities: half +import ClimaCore.Operators import ClimaCore.Operators: AbstractStencilStyle, strip_space import ClimaCore.Operators: setidx!, getidx import ClimaCore.Operators: StencilBroadcasted @@ -10,6 +12,8 @@ import ClimaCore.Operators: LeftBoundaryWindow, RightBoundaryWindow, Interior struct CUDAColumnStencilStyle <: AbstractStencilStyle end AbstractStencilStyle(::ClimaComms.CUDADevice) = CUDAColumnStencilStyle +include("operators_fd_shmem_is_supported.jl") + function Base.copyto!( out::Field, bc::Union{ @@ -21,19 +25,39 @@ function Base.copyto!( bounds = Operators.window_bounds(space, bc) out_fv = Fields.field_values(out) us = DataLayouts.UniversalSize(out_fv) - args = - (strip_space(out, space), strip_space(bc, space), axes(out), bounds, us) - - threads = threads_via_occupancy(copyto_stencil_kernel!, args) - n_max_threads = min(threads, get_N(us)) - p = partition(out_fv, n_max_threads) + fspace = Spaces.face_space(space) + n_face_levels = Spaces.nlevels(fspace) + p = fd_stencil_partition(us, n_face_levels) + args = ( + strip_space(out, space), + strip_space(bc, space), + axes(out), + bounds, + us, + Val(p.Nvthreads), + ) auto_launch!( - copyto_stencil_kernel!, + copyto_stencil_kernel_shmem!, args; threads_s = p.threads, blocks_s = p.blocks, ) + # else + # args = + # (strip_space(out, space), strip_space(bc, space), axes(out), bounds, us) + + # threads = threads_via_occupancy(copyto_stencil_kernel!, args) + # n_max_threads = min(threads, get_N(us)) + # p = partition(out_fv, n_max_threads) + + # auto_launch!( + # copyto_stencil_kernel!, + # args; + # threads_s = p.threads, + # blocks_s = p.blocks, + # ) + # end call_post_op_callback() && post_op_callback(out, out, bc) return out end @@ -64,3 +88,55 @@ function copyto_stencil_kernel!(out, bc, space, bds, us) end return nothing end + + +function copyto_stencil_kernel_shmem!( + out, + bc′, + space, + bds, + us, + ::Val{Nvt}, +) where {Nvt} + @inbounds begin + out_fv = Fields.field_values(out) + us = DataLayouts.UniversalSize(out_fv) + I = fd_stencil_universal_index(space, us) + if fd_stencil_is_valid_index(I, us) # check that hidx is in bounds + (li, lw, rw, ri) = bds + (i, j, _, v, h) = I.I + hidx = (i, j, h) + idx = v - 1 + li + bc = Operators.reconstruct_placeholder_broadcasted(space, bc′) + bc_shmem = fd_allocate_shmem(Val(Nvt), bc) # allocates shmem + + fd_resolve_shmem!(bc_shmem, idx, hidx, bds) # recursively fills shmem + CUDA.sync_threads() + + nv = Spaces.nlevels(space) + isactive = if space isa Operators.AllFaceFiniteDifferenceSpace # check that idx is in bounds + idx + half <= nv + else + idx <= nv + end + if isactive + # Call getidx overloaded in operators_fd_shmem_common.jl + if li <= idx <= (lw - 1) + lwindow = + LeftBoundaryWindow{Spaces.left_boundary_name(space)}() + val = Operators.getidx(space, bc_shmem, lwindow, idx, hidx) + elseif (rw + 1) <= idx <= ri + rwindow = + RightBoundaryWindow{Spaces.right_boundary_name(space)}() + val = Operators.getidx(space, bc_shmem, rwindow, idx, hidx) + else + # 
@assert lw <= idx <= rw + iwindow = Interior() + val = Operators.getidx(space, bc_shmem, iwindow, idx, hidx) + end + setidx!(space, out, idx, hidx, val) + end + end + end + return nothing +end diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index f9f4dd49f3..e3598cf984 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -201,12 +201,12 @@ get_boundary( has_boundary( op::FiniteDifferenceOperator, ::LeftBoundaryWindow{name}, -) where {name} = hasproperty(op.bcs, name) +) where {name} = hasfield(typeof(op.bcs), name) has_boundary( op::FiniteDifferenceOperator, ::RightBoundaryWindow{name}, -) where {name} = hasproperty(op.bcs, name) +) where {name} = hasfield(typeof(op.bcs), name) strip_space(op::FiniteDifferenceOperator, parent_space) = unionall_type(typeof(op))( @@ -231,17 +231,20 @@ This is similar to a `Base.Broadcast.Broadcasted` object. This is returned by `Base.Broadcast.broadcasted(op::FiniteDifferenceOperator)`. """ -struct StencilBroadcasted{Style, Op, Args, Axes} <: OperatorBroadcasted{Style} +struct StencilBroadcasted{Style, Op, Args, Axes, Work} <: + OperatorBroadcasted{Style} op::Op args::Args axes::Axes + work::Work end StencilBroadcasted{Style}( op::Op, args::Args, axes::Axes = nothing, -) where {Style, Op, Args, Axes} = - StencilBroadcasted{Style, Op, Args, Axes}(op, args, axes) + work::Work = nothing, +) where {Style, Op, Args, Axes, Work} = + StencilBroadcasted{Style, Op, Args, Axes, Work}(op, args, axes, work) Adapt.adapt_structure(to, sbc::StencilBroadcasted{Style}) where {Style} = StencilBroadcasted{Style}( @@ -4076,6 +4079,16 @@ function Base.copyto!( return _serial_copyto!(field_out, bc, Ni, Nj, Nh) end +@inline function reconstruct_placeholder_broadcasted( + parent_space::Spaces.AbstractSpace, + sbc::StencilBroadcasted{Style}, +) where {Style} + space = reconstruct_placeholder_space(axes(sbc), parent_space) + args = _reconstruct_placeholder_broadcasted(space, sbc.args...) + return StencilBroadcasted{Style}(sbc.op, args, space, sbc.work) +end + + function window_bounds(space, bc) if Topologies.isperiodic(Spaces.vertical_topology(space)) li = lw = left_idx(space) @@ -4149,3 +4162,32 @@ end #else # return v ⊠ (a⁺ ⊟ RecursiveApply.rdiv((a⁺ - a⁻) ⊠ 𝜙 ,2)) # Current working solution #end + +""" + fd_shmem_is_supported(bc::Base.Broadcast.AbstractBroadcasted) + +Returns a Bool indicating whether or not the broadcasted object supports +shared memory, allowing us to dispatch into an optimized kernel. + +This function and dispatch should be removed once all operators support +shared memory. +""" +function fd_shmem_is_supported end + +""" + any_fd_shmem_supported(::Base.Broadcast.AbstractBroadcasted) + +Returns a Bool indicating if any operators in the broadcasted object support +finite difference shared memory shmem. +""" +function any_fd_shmem_supported end + +if hasfield(Method, :recursion_relation) + dont_limit = (args...) 
-> true + for m in methods(reconstruct_placeholder_broadcasted) + m.recursion_relation = dont_limit + end + for m in methods(_reconstruct_placeholder_broadcasted) + m.recursion_relation = dont_limit + end +end diff --git a/src/Operators/spectralelement.jl b/src/Operators/spectralelement.jl index 284b9c0eec..e588387168 100644 --- a/src/Operators/spectralelement.jl +++ b/src/Operators/spectralelement.jl @@ -91,6 +91,7 @@ Adapt.adapt_structure(to, sbc::SpectralBroadcasted{Style}) where {Style} = sbc.op, Adapt.adapt(to, sbc.args), Adapt.adapt(to, sbc.axes), + Adapt.adapt(to, sbc.work), ) return_space(::SpectralElementOperator, space) = space diff --git a/test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl b/test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl new file mode 100644 index 0000000000..2d2ce85052 --- /dev/null +++ b/test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl @@ -0,0 +1,71 @@ +#= +julia --project=.buildkite +using Revise; include("test/Operators/finitedifference/benchmark_fd_ops_shared_memory.jl") +=# +include("utils_fd_ops_shared_memory.jl") +using BenchmarkTools + +#! format: off +function bench_kernels!(fields) + (; f, ρ, ϕ) = fields + (; ᶜout1, ᶜout2, ᶜout3, ᶜout4, ᶜout5, ᶜout6, ᶜout7, ᶜout8) = fields + device = ClimaComms.device(f) + FT = Spaces.undertype(axes(ϕ)) + div_bcs = Operators.DivergenceF2C(; + bottom = Operators.SetValue(Geometry.Covariant3Vector(FT(100))), + top = Operators.SetValue(Geometry.Covariant3Vector(FT(10000))), + ) + div = Operators.DivergenceF2C() + ᶠwinterp = Operators.WeightedInterpolateC2F( + bottom = Operators.Extrapolate(), + top = Operators.Extrapolate(), + ) + println("ᶜout1: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout1 = $div(Geometry.WVector($f)) + end) + @. ᶜout2 = 0 + println("ᶜout2: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout2 += $div(Geometry.WVector($f) * 2) + end) + println("ᶜout3: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout3 = $div(Geometry.WVector($ᶠwinterp($ϕ, $ρ))) + end) + + println("ᶜout4: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout4 = $div_bcs(Geometry.WVector($f)) + end) + @. ᶜout5 = 0 + println("ᶜout5: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout5 += $div_bcs(Geometry.WVector($f) * 2) + end) + println("ᶜout6: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout6 = $div_bcs(Geometry.WVector($ᶠwinterp($ϕ, $ρ))) + end) + + # from the wild + Ic2f = Operators.InterpolateC2F(; top = Operators.Extrapolate()) + divf2c = Operators.DivergenceF2C(; bottom = Operators.SetValue(Geometry.Covariant3Vector(FT(100000000)))) + # only upward component of divergence + println("ᶜout7: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout7 = $divf2c(Geometry.WVector($Ic2f($ϕ))) # works + end) + println("ᶜout8: ", @benchmark ClimaComms.@cuda_sync $device begin + @. $ᶜout8 = $divf2c($Ic2f(Geometry.WVector($ϕ))) # breaks + end) + return nothing +end; + +#! 
format: on +ᶜspace_cpu = get_space_extruded(ClimaComms.CPUSingleThreaded(), Float64); +ᶠspace_cpu = Spaces.face_space(ᶜspace_cpu); +fields_cpu = (; get_fields(ᶜspace_cpu)..., get_fields(ᶠspace_cpu)...); +bench_kernels!(fields_cpu) +@info "Compiled CPU kernels" + +ᶜspace = get_space_extruded(ClimaComms.device(), Float64); +ᶠspace = Spaces.face_space(ᶜspace); +fields = (; get_fields(ᶜspace)..., get_fields(ᶠspace)...); +bench_kernels!(fields) +@info "Compiled GPU kernels" + +nothing diff --git a/test/Operators/finitedifference/opt_examples.jl b/test/Operators/finitedifference/opt_examples.jl index e2cac97405..9246deed52 100644 --- a/test/Operators/finitedifference/opt_examples.jl +++ b/test/Operators/finitedifference/opt_examples.jl @@ -443,7 +443,7 @@ end @testset "FD operator allocation tests" begin FT = Float64 - n_elems = 1000 + n_elems = 100 domain = Domains.IntervalDomain( Geometry.ZPoint{FT}(0.0), Geometry.ZPoint{FT}(pi); diff --git a/test/Operators/finitedifference/unit_fd_ops_shared_memory.jl b/test/Operators/finitedifference/unit_fd_ops_shared_memory.jl new file mode 100644 index 0000000000..0eda679488 --- /dev/null +++ b/test/Operators/finitedifference/unit_fd_ops_shared_memory.jl @@ -0,0 +1,73 @@ +#= +julia --project=.buildkite +julia --check-bounds=yes -g2 --project=.buildkite +using Revise; include("test/Operators/finitedifference/unit_fd_ops_shared_memory.jl") +=# +include("utils_fd_ops_shared_memory.jl") + +@testset "FD shared memory: dispatch" begin # this ensures that we exercise the correct code-path + FT = Float64 + device = ClimaComms.device() + @test device isa ClimaComms.CUDADevice + ᶜspace = get_space_column(device, FT) + ᶠspace = Spaces.face_space(ᶜspace) + f = Fields.Field(FT, ᶠspace) + grad = Operators.GradientF2C() + bc = @. lazy(grad(f)) + @test !Operators.any_fd_shmem_supported(bc) + div = Operators.DivergenceF2C() + bc = @. lazy(div(Geometry.WVector(f))) + @test Operators.any_fd_shmem_supported(bc) +end + +#! 
format: off +@testset "Correctness column" begin + ᶜspace_cpu = get_space_column(ClimaComms.CPUSingleThreaded(), Float64); + ᶠspace_cpu = Spaces.face_space(ᶜspace_cpu); + fields_cpu = (; get_fields(ᶜspace_cpu)..., get_fields(ᶠspace_cpu)...); + kernels!(fields_cpu) + @info "Compiled CPU kernels" + + ᶜspace = get_space_column(ClimaComms.device(), Float64); + ClimaComms.device(ᶜspace) isa ClimaComms.CPUSingleThreaded && @warn "Running on the CPU" + ᶠspace = Spaces.face_space(ᶜspace); + fields = (; get_fields(ᶜspace)..., get_fields(ᶠspace)...); + kernels!(fields) + @info "Compiled GPU kernels" + + @test compare_cpu_gpu(fields_cpu.ᶜout1, fields.ᶜout1); @test !is_trivial(fields_cpu.ᶜout1) + @test compare_cpu_gpu(fields_cpu.ᶜout2, fields.ᶜout2); @test !is_trivial(fields_cpu.ᶜout2) + @test compare_cpu_gpu(fields_cpu.ᶜout3, fields.ᶜout3); @test !is_trivial(fields_cpu.ᶜout3) + @test compare_cpu_gpu(fields_cpu.ᶜout4, fields.ᶜout4); @test !is_trivial(fields_cpu.ᶜout4) + @test compare_cpu_gpu(fields_cpu.ᶜout5, fields.ᶜout5); @test !is_trivial(fields_cpu.ᶜout5) + @test compare_cpu_gpu(fields_cpu.ᶜout6, fields.ᶜout6); @test !is_trivial(fields_cpu.ᶜout6) + @test compare_cpu_gpu(fields_cpu.ᶜout7, fields.ᶜout7); @test !is_trivial(fields_cpu.ᶜout7) + @test compare_cpu_gpu(fields_cpu.ᶜout8, fields.ᶜout8); @test !is_trivial(fields_cpu.ᶜout8) +end + +@testset "Correctness extruded cubed sphere" begin + ᶜspace_cpu = get_space_extruded(ClimaComms.CPUSingleThreaded(), Float64); + ᶠspace_cpu = Spaces.face_space(ᶜspace_cpu); + fields_cpu = (; get_fields(ᶜspace_cpu)..., get_fields(ᶠspace_cpu)...); + kernels!(fields_cpu) + @info "Compiled CPU kernels" + + ᶜspace = get_space_extruded(ClimaComms.device(), Float64); + ᶠspace = Spaces.face_space(ᶜspace); + fields = (; get_fields(ᶜspace)..., get_fields(ᶠspace)...); + kernels!(fields) + @info "Compiled GPU kernels" + + @test compare_cpu_gpu(fields_cpu.ᶜout1, fields.ᶜout1); @test !is_trivial(fields_cpu.ᶜout1) + @test compare_cpu_gpu(fields_cpu.ᶜout2, fields.ᶜout2); @test !is_trivial(fields_cpu.ᶜout2) + @test compare_cpu_gpu(fields_cpu.ᶜout3, fields.ᶜout3); @test !is_trivial(fields_cpu.ᶜout3) + @test compare_cpu_gpu(fields_cpu.ᶜout4, fields.ᶜout4); @test !is_trivial(fields_cpu.ᶜout4) + @test compare_cpu_gpu(fields_cpu.ᶜout5, fields.ᶜout5); @test !is_trivial(fields_cpu.ᶜout5) + @test compare_cpu_gpu(fields_cpu.ᶜout6, fields.ᶜout6); @test !is_trivial(fields_cpu.ᶜout6) + @test compare_cpu_gpu(fields_cpu.ᶜout7, fields.ᶜout7); @test !is_trivial(fields_cpu.ᶜout7) + @test compare_cpu_gpu(fields_cpu.ᶜout8, fields.ᶜout8); @test !is_trivial(fields_cpu.ᶜout8) + +end + +#! 
format: on +nothing diff --git a/test/Operators/finitedifference/utils_fd_ops_shared_memory.jl b/test/Operators/finitedifference/utils_fd_ops_shared_memory.jl new file mode 100644 index 0000000000..0963d4ec13 --- /dev/null +++ b/test/Operators/finitedifference/utils_fd_ops_shared_memory.jl @@ -0,0 +1,146 @@ +ENV["CLIMACOMMS_DEVICE"] = "CUDA"; # requires cuda +using LazyBroadcast: lazy +using ClimaCore.Utilities: half +using Test, ClimaComms +ClimaComms.@import_required_backends; +using ClimaCore: Geometry, Spaces, Fields, Operators, ClimaCore; +using ClimaCore.CommonSpaces; + +get_space_extruded(dev, FT) = ExtrudedCubedSphereSpace( + FT; + device = dev, + z_elem = 63, + z_min = 0, + z_max = 1, + radius = 10, + h_elem = 30, + n_quad_points = 4, + staggering = CellCenter(), +); + +get_space_column(dev, FT) = ColumnSpace( + FT; + device = dev, + z_elem = 10, + z_min = 0, + z_max = 1, + staggering = CellCenter(), +); + +function kernels!(fields) + (; f, ρ, ϕ) = fields + (; ᶜout1, ᶜout2, ᶜout3, ᶜout4, ᶜout5, ᶜout6, ᶜout7, ᶜout8) = fields + FT = Spaces.undertype(axes(ϕ)) + div_bcs = Operators.DivergenceF2C(; + bottom = Operators.SetValue(Geometry.Covariant3Vector(FT(100))), + top = Operators.SetValue(Geometry.Covariant3Vector(FT(10000))), + ) + div = Operators.DivergenceF2C() + ᶠwinterp = Operators.WeightedInterpolateC2F( + bottom = Operators.Extrapolate(), + top = Operators.Extrapolate(), + ) + @. ᶜout1 = div(Geometry.WVector(f)) + @. ᶜout2 = 0 + @. ᶜout2 += div(Geometry.WVector(f) * 2) + @. ᶜout3 = div(Geometry.WVector(ᶠwinterp(ϕ, ρ))) + + @. ᶜout4 = div_bcs(Geometry.WVector(f)) + @. ᶜout5 = 0 + @. ᶜout5 += div_bcs(Geometry.WVector(f) * 2) + @. ᶜout6 = div_bcs(Geometry.WVector(ᶠwinterp(ϕ, ρ))) + + # from the wild + Ic2f = Operators.InterpolateC2F(; top = Operators.Extrapolate()) + divf2c = Operators.DivergenceF2C(; + bottom = Operators.SetValue(Geometry.Covariant3Vector(FT(100000000))), + ) + # only upward component of divergence + @. ᶜout7 = divf2c(Geometry.WVector(Ic2f(ϕ))) # works + @. ᶜout8 = divf2c(Ic2f(Geometry.WVector(ϕ))) # breaks + return nothing +end; + +function get_fields(space::Operators.AllFaceFiniteDifferenceSpace) + FT = Spaces.undertype(space) + (; z) = Fields.coordinate_field(space) + nt = (; f = Fields.Field(FT, space)) + @. nt.f = sin(z) + return nt +end + +function get_fields(space::Operators.AllCenterFiniteDifferenceSpace) + FT = Spaces.undertype(space) + K = (ntuple(i -> Symbol("ᶜout$i"), 8)..., :ρ, :ϕ) + V = ntuple(i -> Fields.zeros(space), length(K)) + (; z) = Fields.coordinate_field(space) + nt = (; zip(K, V)...) + @. nt.ρ = sin(z) + @. nt.ϕ = sin(z) + return nt +end + +function compare_cpu_gpu(cpu, gpu; print_diff = true, C_best = 10) + # there are some odd errors that build up when run without debug / bounds checks: + space = axes(cpu) + are_boundschecks_forced = Base.JLOptions().check_bounds == 1 + absΔ = abs.(parent(cpu) .- Array(parent(gpu))) + B = + are_boundschecks_forced ? maximum(absΔ) <= 1000 * eps() : + maximum(absΔ) <= 10000000 * eps() + C = + are_boundschecks_forced ? 
count(x -> x <= 1000 * eps(), absΔ) : + count(x -> x <= 10000000 * eps(), absΔ) + if !B && print_diff + if space isa Spaces.FiniteDifferenceSpace + @show parent(cpu)[1:3] + @show parent(gpu)[1:3] + @show parent(cpu)[(end - 3):end] + @show parent(gpu)[(end - 3):end] + else + @show parent(cpu)[1:3, 1, 1, 1, end] + @show parent(gpu)[1:3, 1, 1, 1, end] + @show parent(cpu)[(end - 3):end, 1, 1, 1, end] + @show parent(gpu)[(end - 3):end, 1, 1, 1, end] + end + end + @test B + return B +end + +# This function is useful for debugging new cases. +function compare_cpu_gpu_incremental(cpu, gpu; print_diff = true, C_best = 10) + # there are some odd errors that build up when run without debug / bounds checks: + space = axes(cpu) + are_boundschecks_forced = Base.JLOptions().check_bounds == 1 + absΔ = abs.(parent(cpu) .- Array(parent(gpu))) + B = + are_boundschecks_forced ? maximum(absΔ) <= 1000 * eps() : + maximum(absΔ) <= 10000000 * eps() + C = + are_boundschecks_forced ? count(x -> x <= 1000 * eps(), absΔ) : + count(x -> x <= 10000000 * eps(), absΔ) + @test C ≥ C_best + if !(C_best == 10) + C > C_best && @show C_best + @test_broken C > C_best + end + if !B && print_diff + if space isa Spaces.FiniteDifferenceSpace + @show parent(cpu)[1:3] + @show parent(gpu)[1:3] + @show parent(cpu)[(end - 3):end] + @show parent(gpu)[(end - 3):end] + else + @show parent(cpu)[1:3, 1, 1, 1, end] + @show parent(gpu)[1:3, 1, 1, 1, end] + @show parent(cpu)[(end - 3):end, 1, 1, 1, end] + @show parent(gpu)[(end - 3):end, 1, 1, 1, end] + end + end + return true +end + +is_trivial(x) = length(parent(x)) == count(iszero, parent(x)) # Make sure we don't have a trivial solution + +nothing
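Usage sketch (not part of the patch): a minimal example of how the shared-memory path is exercised from
user code, assuming the helpers in test/Operators/finitedifference/utils_fd_ops_shared_memory.jl above
are loaded; the field names `f` and `ᶜout` are illustrative only.

include("utils_fd_ops_shared_memory.jl")

ᶜspace = get_space_column(ClimaComms.device(), Float64)  # center space (see utils above)
ᶠspace = Spaces.face_space(ᶜspace)
f = Fields.Field(Float64, ᶠspace)
@. f = 1                                                 # any face-valued data
ᶜout = Fields.zeros(ᶜspace)
divf2c = Operators.DivergenceF2C()
# On a CUDADevice this broadcast dispatches to `copyto_stencil_kernel_shmem!`,
# since `fd_shmem_is_supported(::DivergenceF2C)` returns `true`; on the CPU the
# existing serial path is unchanged.
@. ᶜout = divf2c(Geometry.WVector(f))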