Skip to content

Commit

Permalink
Merge pull request #2136 from CliMA/dy/deep_dss
Browse files Browse the repository at this point in the history
 Clarify DSS weights and fix DSS buffer bug
  • Loading branch information
dennisYatunin authored Jan 28, 2025
2 parents cee1557 + 590b1c9 commit 2623ded
Show file tree
Hide file tree
Showing 15 changed files with 120 additions and 99 deletions.
12 changes: 6 additions & 6 deletions ext/cuda/topologies_dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import ClimaCore: DataLayouts, Topologies, Spaces, Fields
import ClimaCore.DataLayouts: CartesianFieldIndex
using CUDA
import ClimaCore.Topologies
import ClimaCore.Topologies: DSSDataTypes, DSSPerimeterTypes, DSSWeightTypes
import ClimaCore.Topologies: DSSDataTypes, DSSPerimeterTypes
import ClimaCore.Topologies: perimeter_vertex_node_index

_max_threads_cuda() = 256
Expand Down Expand Up @@ -198,7 +198,7 @@ function Topologies.dss_transform!(
data::DSSDataTypes,
perimeter::Topologies.Perimeter2D,
local_geometry::DSSDataTypes,
weight::DSSWeightTypes,
dss_weights::DSSDataTypes,
localelems::AbstractVector{Int},
)
nlocalelems = length(localelems)
Expand All @@ -214,7 +214,7 @@ function Topologies.dss_transform!(
data,
perimeter,
local_geometry,
weight,
dss_weights,
localelems,
Val(nlocalelems),
)
Expand All @@ -231,7 +231,7 @@ function Topologies.dss_transform!(
data,
perimeter,
local_geometry,
weight,
dss_weights,
localelems,
)
end
Expand All @@ -243,7 +243,7 @@ function dss_transform_kernel!(
data::DSSDataTypes,
perimeter::Topologies.Perimeter2D,
local_geometry::DSSDataTypes,
weight::DSSWeightTypes,
dss_weights::DSSDataTypes,
localelems::AbstractVector{Int},
::Val{nlocalelems},
) where {nlocalelems}
Expand All @@ -260,7 +260,7 @@ function dss_transform_kernel!(
src = Topologies.dss_transform(
data[loc],
local_geometry[loc],
weight[loc],
dss_weights[loc],
)
perimeter_data[CI(p, 1, 1, level, elem)] =
Topologies.drop_vert_dim(eltype(perimeter_data), src)
Expand Down
7 changes: 2 additions & 5 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,8 @@ end
Create a buffer for communicating neighbour information of `field`.
"""
function Spaces.create_dss_buffer(field::Field)
space = axes(field)
hspace = Spaces.horizontal_space(space)
Spaces.create_dss_buffer(field_values(field), hspace)
end
Spaces.create_dss_buffer(field::Field) =
Spaces.create_dss_buffer(field_values(field), axes(field))

Base.@propagate_inbounds function level(
field::Union{
Expand Down
2 changes: 1 addition & 1 deletion src/Grids/Grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function local_geometry_type end
# Fallback, but this requires user error-handling
local_geometry_type(::Type{T}) where {T} = Union{}

function local_dss_weights end
function dss_weights end
function quadrature_style end
function vertical_topology end

Expand Down
7 changes: 5 additions & 2 deletions src/Grids/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,11 @@ topology(grid::ExtrudedFiniteDifferenceGrid) = topology(grid.horizontal_grid)
vertical_topology(grid::ExtrudedFiniteDifferenceGrid) =
topology(grid.vertical_grid)

local_dss_weights(grid::ExtrudedFiniteDifferenceGrid) =
local_dss_weights(grid.horizontal_grid)
# Since ∂z/∂ξ₃ and r are continuous across element boundaries, we can reuse
# the horizontal weights instead of calling compute_dss_weights on the
# extruded local geometry.
dss_weights(grid::AbstractExtrudedFiniteDifferenceGrid, ::Staggering) =
dss_weights(grid.horizontal_grid, nothing)

local_geometry_data(grid::AbstractExtrudedFiniteDifferenceGrid, ::CellCenter) =
grid.center_local_geometry
Expand Down
5 changes: 4 additions & 1 deletion src/Grids/level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ level(

topology(levelgrid::LevelGrid) = topology(levelgrid.full_grid)

local_dss_weights(grid::LevelGrid) = local_dss_weights(grid.full_grid)
dss_weights(levelgrid::LevelGrid{<:Any, Int}, ::Nothing) =
level(dss_weights(levelgrid.full_grid, CellCenter()), levelgrid.level)
dss_weights(levelgrid::LevelGrid{<:Any, PlusHalf{Int}}, ::Nothing) =
level(dss_weights(levelgrid.full_grid, CellFace()), levelgrid.level + half)

local_geometry_type(::Type{LevelGrid{G, L}}) where {G, L} =
local_geometry_type(G)
Expand Down
40 changes: 22 additions & 18 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,13 @@ function _SpectralElementGrid1D(
)
end
end
dss_weights = copy(local_geometry.J)
dss_weights .= one(FT)
Topologies.dss_1d!(topology, dss_weights)
dss_weights = one(FT) ./ dss_weights

return SpectralElementGrid1D(
topology,
quadrature_style,
global_geometry,
local_geometry,
dss_weights,
compute_dss_weights(local_geometry, topology, quadrature_style),
)
end

Expand All @@ -136,7 +132,7 @@ mutable struct SpectralElementGrid2D{
quadrature_style::Q
global_geometry::GG
local_geometry::LG
local_dss_weights::D
dss_weights::D
internal_surface_geometry::IS
boundary_surface_geometries::BS
enable_bubble::Bool
Expand Down Expand Up @@ -418,14 +414,6 @@ function _SpectralElementGrid2D(
end
end

# dss_weights = J ./ dss(J)
J = DataLayouts.rebuild(local_geometry.J, DA)
dss_local_weights = copy(J)
if quadrature_style isa Quadratures.GLL
Topologies.dss!(dss_local_weights, topology)
end
dss_local_weights .= J ./ dss_local_weights

SG = Geometry.SurfaceGeometry{
FT,
Geometry.AxisVector{FT, Geometry.LocalAxis{AIdx}, SVector{2, FT}},
Expand Down Expand Up @@ -497,12 +485,14 @@ function _SpectralElementGrid2D(
internal_surface_geometry = nothing
boundary_surface_geometries = nothing
end

device_local_geometry = DataLayouts.rebuild(local_geometry, DA)
return SpectralElementGrid2D(
topology,
quadrature_style,
global_geometry,
DataLayouts.rebuild(local_geometry, DA),
dss_local_weights,
device_local_geometry,
compute_dss_weights(device_local_geometry, topology, quadrature_style),
internal_surface_geometry,
boundary_surface_geometries,
enable_bubble,
Expand Down Expand Up @@ -578,6 +568,21 @@ function compute_surface_geometry(
return Geometry.SurfaceGeometry(sWJ, n)
end

function compute_dss_weights(local_geometry, topology, quadrature_style)
is_dss_required =
Quadratures.unique_degrees_of_freedom(quadrature_style) <
Quadratures.degrees_of_freedom(quadrature_style)
# Although the weights are defined as WJ / Σ collocated WJ, we can use J
# instead of WJ if the weights are symmetric across element boundaries.
dss_weights = copy(local_geometry.J)
if topology isa Topologies.IntervalTopology
is_dss_required && Topologies.dss_1d!(topology, dss_weights)
else
is_dss_required && Topologies.dss!(dss_weights, topology)
end
@. dss_weights = local_geometry.J / dss_weights
return dss_weights
end

# accessors

Expand All @@ -588,8 +593,7 @@ local_geometry_data(grid::AbstractSpectralElementGrid, ::Nothing) =
global_geometry(grid::AbstractSpectralElementGrid) = grid.global_geometry

quadrature_style(grid::AbstractSpectralElementGrid) = grid.quadrature_style
local_dss_weights(grid::SpectralElementGrid1D) = grid.dss_weights
local_dss_weights(grid::SpectralElementGrid2D) = grid.local_dss_weights
dss_weights(grid::AbstractSpectralElementGrid, ::Nothing) = grid.dss_weights

## GPU compatibility
struct DeviceSpectralElementGrid2D{Q, GG, LG} <: AbstractSpectralElementGrid
Expand Down
35 changes: 13 additions & 22 deletions src/Quadratures/Quadratures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@ export QuadratureStyle,
GLL, GL, polynomial_degree, degrees_of_freedom, quadrature_points

"""
QuadratureStyle
QuadratureStyle{Nq}
Quadrature style supertype. See sub-types:
- [`GLL`](@ref)
- [`GL`](@ref)
- [`Uniform`](@ref)
"""
abstract type QuadratureStyle end
abstract type QuadratureStyle{Nq} end

"""
polynomial_degree(QuadratureStyle) -> Int
Returns the polynomial degree of the `QuadratureStyle` concrete type
"""
function polynomial_degree end
@inline polynomial_degree(::QuadratureStyle{Nq}) where {Nq} = Nq - 1


"""
degrees_of_freedom(QuadratureStyle) -> Int
Returns the degrees_of_freedom of the `QuadratureStyle` concrete type
"""
function degrees_of_freedom end
@inline degrees_of_freedom(::QuadratureStyle{Nq}) where {Nq} = Nq

"""
points, weights = quadrature_points(::Type{FT}, quadrature_style)
Expand All @@ -46,15 +46,12 @@ function quadrature_points end
Gauss-Legendre-Lobatto quadrature using `Nq` quadrature points.
"""
struct GLL{Nq} <: QuadratureStyle end
struct GLL{Nq} <: QuadratureStyle{Nq} end

Base.show(io::IO, ::GLL{Nq}) where {Nq} =
print(io, Nq, "-point Gauss-Legendre-Lobatto quadrature")

@inline polynomial_degree(::GLL{Nq}) where {Nq} = Int(Nq - 1)
@inline degrees_of_freedom(::GLL{Nq}) where {Nq} = Int(Nq)
unique_degrees_of_freedom(::GLL{Nq}) where {Nq} = Int(Nq - 1)

unique_degrees_of_freedom(::GLL{Nq}) where {Nq} = Nq - 1
@generated function quadrature_points(::Type{FT}, ::GLL{Nq}) where {FT, Nq}
points, weights = GaussQuadrature.legendre(FT, Nq, GaussQuadrature.both)
:($(SVector{Nq}(points)), $(SVector{Nq}(weights)))
Expand All @@ -65,14 +62,12 @@ end
Gauss-Legendre quadrature using `Nq` quadrature points.
"""
struct GL{Nq} <: QuadratureStyle end
struct GL{Nq} <: QuadratureStyle{Nq} end

Base.show(io::IO, ::GL{Nq}) where {Nq} =
print(io, Nq, "-point Gauss-Legendre quadrature")

@inline polynomial_degree(::GL{Nq}) where {Nq} = Int(Nq - 1)
@inline degrees_of_freedom(::GL{Nq}) where {Nq} = Int(Nq)
unique_degrees_of_freedom(::GL{Nq}) where {Nq} = Int(Nq)

unique_degrees_of_freedom(::GL{Nq}) where {Nq} = Nq
@generated function quadrature_points(::Type{FT}, ::GL{Nq}) where {FT, Nq}
points, weights = GaussQuadrature.legendre(FT, Nq, GaussQuadrature.neither)
:($(SVector{Nq}(points)), $(SVector{Nq}(weights)))
Expand All @@ -83,11 +78,9 @@ end
Uniformly-spaced quadrature.
"""
struct Uniform{Nq} <: QuadratureStyle end

@inline polynomial_degree(::Uniform{Nq}) where {Nq} = Int(Nq - 1)
@inline degrees_of_freedom(::Uniform{Nq}) where {Nq} = Int(Nq)
struct Uniform{Nq} <: QuadratureStyle{Nq} end

unique_degrees_of_freedom(::Uniform{Nq}) where {Nq} = Nq
@generated function quadrature_points(::Type{FT}, ::Uniform{Nq}) where {FT, Nq}
points = SVector{Nq}(range(-1 + FT(1 / Nq), step = FT(2 / Nq), length = Nq))
weights = SVector{Nq}(ntuple(i -> FT(2 / Nq), Nq))
Expand All @@ -99,11 +92,9 @@ end
Uniformly-spaced quadrature including boundary.
"""
struct ClosedUniform{Nq} <: QuadratureStyle end

@inline polynomial_degree(::ClosedUniform{Nq}) where {Nq} = Int(Nq - 1)
@inline degrees_of_freedom(::ClosedUniform{Nq}) where {Nq} = Int(Nq)
struct ClosedUniform{Nq} <: QuadratureStyle{Nq} end

unique_degrees_of_freedom(::ClosedUniform{Nq}) where {Nq} = Nq - 1
@generated function quadrature_points(
::Type{FT},
::ClosedUniform{Nq},
Expand Down
3 changes: 2 additions & 1 deletion src/Spaces/Spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import ..Grids:
local_geometry_type,
local_geometry_data,
global_geometry,
local_dss_weights,
dss_weights,
quadrature_style

import ClimaComms
Expand Down Expand Up @@ -69,6 +69,7 @@ vertical_topology(space::AbstractSpace) = vertical_topology(grid(space))

local_geometry_data(space::AbstractSpace) =
local_geometry_data(grid(space), staggering(space))
dss_weights(space::AbstractSpace) = dss_weights(grid(space), staggering(space))

function n_elements_per_panel_direction(space::AbstractSpace)
hspace = Spaces.horizontal_space(space)
Expand Down
24 changes: 10 additions & 14 deletions src/Spaces/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import ..Topologies:
load_from_recv_buffer!,
DSSTypesAll,
DSSDataTypes,
DSSPerimeterTypes,
DSSWeightTypes
DSSPerimeterTypes


perimeter(space::AbstractSpectralElementSpace) = Topologies.Perimeter2D(
Expand All @@ -23,10 +22,7 @@ perimeter(space::AbstractSpectralElementSpace) = Topologies.Perimeter2D(


"""
create_dss_buffer(
data::Union{DataLayouts.IJFH, DataLayouts.VIJFH},
hspace::AbstractSpectralElementSpace,
)
create_dss_buffer(data, space)
Creates a [`DSSBuffer`](@ref) for the field data corresponding to `data`
"""
Expand All @@ -37,13 +33,13 @@ function create_dss_buffer(
DataLayouts.VIJFH,
DataLayouts.VIJHF,
},
hspace::SpectralElementSpace2D,
space,
)
create_dss_buffer(
data,
topology(hspace),
local_geometry_data(hspace),
local_dss_weights(hspace),
topology(space),
local_geometry_data(space),
dss_weights(space),
)
end

Expand All @@ -54,7 +50,7 @@ function create_dss_buffer(
DataLayouts.VIFH,
DataLayouts.VIHF,
},
hspace::SpectralElementSpace1D,
space,
)
nothing
end
Expand Down Expand Up @@ -122,7 +118,7 @@ function weighted_dss_prepare!(
dss_buffer,
data,
local_geometry_data(space),
local_dss_weights(hspace),
dss_weights(space),
Spaces.perimeter(hspace),
dss_buffer.perimeter_elems,
)
Expand Down Expand Up @@ -236,7 +232,7 @@ function weighted_dss_internal!(
topology(hspace),
data,
local_geometry_data(space),
local_dss_weights(space),
dss_weights(space),
)
else
device = ClimaComms.device(topology(hspace))
Expand All @@ -245,7 +241,7 @@ function weighted_dss_internal!(
dss_buffer,
data,
local_geometry_data(space),
local_dss_weights(space),
dss_weights(space),
Spaces.perimeter(hspace),
dss_buffer.internal_elems,
)
Expand Down
4 changes: 0 additions & 4 deletions src/Spaces/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ FaceExtrudedFiniteDifferenceSpace(space::ExtrudedFiniteDifferenceSpace) =
CenterExtrudedFiniteDifferenceSpace(space::ExtrudedFiniteDifferenceSpace) =
ExtrudedFiniteDifferenceSpace(grid(space), CellCenter())


local_dss_weights(space::ExtrudedFiniteDifferenceSpace) =
local_dss_weights(grid(space))

staggering(space::ExtrudedFiniteDifferenceSpace) = getfield(space, :staggering)
grid(space::ExtrudedFiniteDifferenceSpace) = getfield(space, :grid)
space(space::ExtrudedFiniteDifferenceSpace, staggering::Staggering) =
Expand Down
Loading

0 comments on commit 2623ded

Please sign in to comment.