Skip to content

Commit

Permalink
Fuse sponge operations, using LazyBroadcast
Browse files Browse the repository at this point in the history
Slim method arguments

Use null broadcasted
  • Loading branch information
charleskawczynski committed Jan 31, 2025
1 parent 989a900 commit 93ca5de
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 82 deletions.
2 changes: 2 additions & 0 deletions src/ClimaAtmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using NVTX
import LazyBroadcast
import Thermodynamics as TD

include("null_broadcasted.jl")
include("compat.jl")
include(joinpath("parameters", "Parameters.jl"))
import .Parameters as CAP
Expand Down Expand Up @@ -112,6 +113,7 @@ include(
)
include(joinpath("parameterized_tendencies", "sponge", "rayleigh_sponge.jl"))
include(joinpath("parameterized_tendencies", "sponge", "viscous_sponge.jl"))
include(joinpath("parameterized_tendencies", "sponge", "sponge_tendencies.jl"))
include(
joinpath(
"parameterized_tendencies",
Expand Down
100 changes: 100 additions & 0 deletions src/null_broadcasted.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# TODO: use https://github.com/CliMA/NullBroadcasts.jl when released

"""
NullBroadcasted()
A `Base.AbstractBroadcasted` that represents arithmetic object.
An `NullBroadcasted()` can be added to, subtracted from, or multiplied by any
value in a broadcast expression without incurring a runtime performance
penalty.
For example, the following rules hold when broadcasting instances of
`NullBroadcasted`:
```
1 + NullBroadcasted() == 1
NullBroadcasted() + 1 == 1
1 - NullBroadcasted() == 1
1 * NullBroadcasted() == NullBroadcasted()
1 / NullBroadcasted() == NullBroadcasted()
```
"""
struct NullBroadcasted <: Base.AbstractBroadcasted end
Base.broadcastable(x::NullBroadcasted) = x

struct NullBroadcastedStyle <: Base.BroadcastStyle end
Base.BroadcastStyle(::Type{<:NullBroadcasted}) = NullBroadcasted()

# Specialize on AbstractArrayStyle to avoid ambiguities with AbstractBroadcasted.
Base.BroadcastStyle(::NullBroadcasted, ::Base.Broadcast.AbstractArrayStyle) =
NullBroadcasted()
Base.BroadcastStyle(::Base.Broadcast.AbstractArrayStyle, ::NullBroadcasted) =
NullBroadcasted()

# Add another method to avoid ambiguity between the previous two.
Base.BroadcastStyle(::NullBroadcasted, ::NullBroadcasted) = NullBroadcasted()

broadcasted_sum(args) =
if isempty(args)
NullBroadcasted()
elseif length(args) == 1
args[1]
else
Base.broadcasted(+, args...)
end
Base.broadcasted(::NullBroadcasted, ::typeof(+), args...) =
broadcasted_sum(filter(arg -> !(arg isa NullBroadcasted), args))

Base.broadcasted(op::typeof(-), ::NullBroadcasted, arg) =
Base.broadcasted(op, arg)
Base.broadcasted(op::typeof(-), arg, ::NullBroadcasted) =
Base.broadcasted(Base.identity, arg)
Base.broadcasted(op::typeof(-), a::NullBroadcasted) = NullBroadcasted()
Base.broadcasted(op::typeof(-), a::NullBroadcasted, ::NullBroadcasted) =
Base.broadcasted(op, a)

Base.broadcasted(op::typeof(+), ::NullBroadcasted, args...) =
Base.broadcasted(op, args...)
Base.broadcasted(op::typeof(+), arg, ::NullBroadcasted, args...) =
Base.broadcasted(op, arg, args...)
Base.broadcasted(
op::typeof(+),
a::NullBroadcasted,
::NullBroadcasted,
args...,
) = Base.broadcasted(op, a, args...)

Base.broadcasted(op::typeof(*), ::NullBroadcasted, args...) = NullBroadcasted()
Base.broadcasted(op::typeof(*), arg, ::NullBroadcasted) = NullBroadcasted()
Base.broadcasted(op::typeof(*), ::NullBroadcasted, ::NullBroadcasted) =
NullBroadcasted()
Base.broadcasted(op::typeof(/), ::NullBroadcasted, args...) = NullBroadcasted()
Base.broadcasted(op::typeof(/), arg, ::NullBroadcasted) = NullBroadcasted()
Base.broadcasted(op::typeof(/), ::NullBroadcasted, ::NullBroadcasted) =
NullBroadcasted()

Base.broadcasted(op::typeof(identity), a::NullBroadcasted) = a

function skip_materialize(dest, bc::Base.Broadcast.Broadcasted)
if typeof(bc.f) <: typeof(+) || typeof(bc.f) <: typeof(-)
if length(bc.args) == 2 &&
bc.args[1] === dest &&
bc.args[2] === Base.Broadcast.Broadcasted(NullBroadcasted, ())
return true
else
return false
end
else
return false
end
end

Base.Broadcast.instantiate(
bc::Base.Broadcast.Broadcasted{NullBroadcastedStyle},
) = x

Base.Broadcast.materialize!(dest, x::NullBroadcasted) =
error("NullBroadcasted objects cannot be materialized.")
Base.Broadcast.materialize(dest, x::NullBroadcasted) =
error("NullBroadcasted objects cannot be materialized.")
12 changes: 6 additions & 6 deletions src/parameterized_tendencies/sponge/rayleigh_sponge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
##### Rayleigh sponge
#####

import LazyBroadcast: @lazy
import ClimaCore.Fields as Fields

rayleigh_sponge_tendency!(Yₜ, Y, p, t, ::Nothing) = nothing

αₘ(s::RayleighSponge{FT}, z, α) where {FT} = ifelse(z > s.zd, α, FT(0))
ζ_rayleigh(s::RayleighSponge{FT}, z, zmax) where {FT} =
sin(FT(π) / 2 * (z - s.zd) / (zmax - s.zd))^2
Expand All @@ -14,8 +13,9 @@ rayleigh_sponge_tendency!(Yₜ, Y, p, t, ::Nothing) = nothing
β_rayleigh_w(s::RayleighSponge{FT}, z, zmax) where {FT} =
αₘ(s, z, s.α_w) * ζ_rayleigh(s, z, zmax)

function rayleigh_sponge_tendency!(Yₜ, Y, p, t, s::RayleighSponge)
ᶜz = Fields.coordinate_field(Y.c).z
zmax = z_max(axes(Y.f))
@. Yₜ.c.uₕ -= β_rayleigh_uₕ(s, ᶜz, zmax) * Y.c.uₕ
function rayleigh_sponge_tendency_uₕ(ᶜuₕ, s)
s isa Nothing && return NullBroadcasted()
(; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ))
zmax = z_max(axes(ᶠz))
return @lazy @. -β_rayleigh_uₕ(s, ᶜz, zmax) * ᶜuₕ
end
44 changes: 44 additions & 0 deletions src/parameterized_tendencies/sponge/sponge_tendencies.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#####
##### Sponge tendencies
#####

import LazyBroadcast: @lazy
import ClimaCore.Fields as Fields
import ClimaCore.Geometry as Geometry
import ClimaCore.Spaces as Spaces

function z_coordinate_fields(space::Spaces.AbstractSpace)
ᶜz = Fields.coordinate_field(Spaces.center_space(space)).z
ᶠz = Fields.coordinate_field(Spaces.face_space(space)).z
return (; ᶜz, ᶠz)
end


function sponge_tendencies!(Yₜ, Y, p, t)
rs, vs = p.atmos.rayleigh_sponge, p.atmos.viscous_sponge
(; ᶜh_tot, ᶜspecific) = p.precomputed
ᶜuₕ = Y.c.uₕ
ᶠu₃ = Yₜ.f.u₃
ᶜρ = Y.c.ρ
vst_uₕ = viscous_sponge_tendency_uₕ(ᶜuₕ, vs)
vst_u₃ = viscous_sponge_tendency_u₃(ᶠu₃, vs)
vst_ρe_tot = viscous_sponge_tendency_ρe_tot(ᶜρ, ᶜh_tot, vs)
rst_uₕ = rayleigh_sponge_tendency_uₕ(ᶜuₕ, rs)

# TODO: fuse, once we fix
# https://github.com/CliMA/ClimaCore.jl/issues/2165
@. Yₜ.c.uₕ += vst_uₕ
@. Yₜ.c.uₕ += rst_uₕ
@. Yₜ.f.u₃.components.data.:1 += vst_u₃
@. Yₜ.c.ρe_tot += vst_ρe_tot

# TODO: can we write this out explicitly?
if vs isa ViscousSponge
for (ᶜρχₜ, ᶜχ, χ_name) in matching_subfields(Yₜ.c, ᶜspecific)
χ_name == :e_tot && continue
vst_tracer = viscous_sponge_tendency_tracer(ᶜρ, ᶜχ, vs)
@. ᶜρχₜ += vst_tracer
@. Yₜ.c.ρ += vst_tracer
end
end
end
51 changes: 29 additions & 22 deletions src/parameterized_tendencies/sponge/viscous_sponge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,46 @@
##### Viscous sponge
#####

import LazyBroadcast: @lazy
import ClimaCore.Fields as Fields
import ClimaCore.Geometry as Geometry
import ClimaCore.Spaces as Spaces


viscous_sponge_tendency!(Yₜ, Y, p, t, ::Nothing) = nothing

αₘ(s::ViscousSponge{FT}, z) where {FT} = ifelse(z > s.zd, s.κ₂, FT(0))
ζ_viscous(s::ViscousSponge{FT}, z, zmax) where {FT} =
sin(FT(π) / 2 * (z - s.zd) / (zmax - s.zd))^2
β_viscous(s::ViscousSponge{FT}, z, zmax) where {FT} =
αₘ(s, z) * ζ_viscous(s, z, zmax)

function viscous_sponge_tendency!(Yₜ, Y, p, t, s::ViscousSponge)
(; ᶜh_tot, ᶜspecific) = p.precomputed
ᶜuₕ = Y.c.uₕ
ᶜz = Fields.coordinate_field(Y.c).z
ᶠz = Fields.coordinate_field(Y.f).z
function viscous_sponge_tendency_uₕ(ᶜuₕ, s)
s isa Nothing && return NullBroadcasted()
(; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜuₕ))
zmax = z_max(axes(ᶠz))
@. Yₜ.c.uₕ +=
β_viscous(s, ᶜz, zmax) * (
wgradₕ(divₕ(ᶜuₕ)) - Geometry.project(
Geometry.Covariant12Axis(),
wcurlₕ(Geometry.project(Geometry.Covariant3Axis(), curlₕ(ᶜuₕ))),
)
return @lazy @. β_viscous(s, ᶜz, zmax) * (
wgradₕ(divₕ(ᶜuₕ)) - Geometry.project(
Geometry.Covariant12Axis(),
wcurlₕ(Geometry.project(Geometry.Covariant3Axis(), curlₕ(ᶜuₕ))),
)
@. Yₜ.f.u₃.components.data.:1 +=
β_viscous(s, ᶠz, zmax) * wdivₕ(gradₕ(Y.f.u₃.components.data.:1))
)
end

function viscous_sponge_tendency_u₃(u₃, s)
s isa Nothing && return NullBroadcasted()
(; ᶠz) = z_coordinate_fields(axes(u₃))
zmax = z_max(axes(ᶠz))
return @lazy @. β_viscous(s, ᶠz, zmax) * wdivₕ(gradₕ(u₃.components.data.:1))
end

@. Yₜ.c.ρe_tot += β_viscous(s, ᶜz, zmax) * wdivₕ(Y.c.ρ * gradₕ(ᶜh_tot))
for (ᶜρχₜ, ᶜχ, χ_name) in matching_subfields(Yₜ.c, ᶜspecific)
χ_name == :e_tot && continue
@. ᶜρχₜ += β_viscous(s, ᶜz, zmax) * wdivₕ(Y.c.ρ * gradₕ(ᶜχ))
@. Yₜ.c.ρ += β_viscous(s, ᶜz, zmax) * wdivₕ(Y.c.ρ * gradₕ(ᶜχ))
end
function viscous_sponge_tendency_ρe_tot(ᶜρ, ᶜh_tot, s)
s isa Nothing && return NullBroadcasted()
(; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜρ))
zmax = z_max(axes(ᶠz))
return @lazy @. β_viscous(s, ᶜz, zmax) * wdivₕ(ᶜρ * gradₕ(ᶜh_tot))
end

function viscous_sponge_tendency_tracer(ᶜρ, ᶜχ, s)
s isa Nothing && return NullBroadcasted()
(; ᶜz, ᶠz) = z_coordinate_fields(axes(ᶜρ))
zmax = z_max(axes(ᶠz))
return @lazy @. β_viscous(s, ᶜz, zmax) * wdivₕ(ᶜρ * gradₕ(ᶜχ))
end
4 changes: 1 addition & 3 deletions src/prognostic_equations/remaining_tendency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ NVTX.@annotate function remaining_tendency!(Yₜ, Yₜ_lim, Y, p, t)
end

NVTX.@annotate function additional_tendency!(Yₜ, Y, p, t)
viscous_sponge_tendency!(Yₜ, Y, p, t, p.atmos.viscous_sponge)

sponge_tendencies!(Yₜ, Y, p, t)
# Vertical tendencies
rayleigh_sponge_tendency!(Yₜ, Y, p, t, p.atmos.rayleigh_sponge)
forcing_tendency!(Yₜ, Y, p, t, p.atmos.forcing_type)
subsidence_tendency!(Yₜ, Y, p, t, p.atmos.subsidence)
edmf_coriolis_tendency!(Yₜ, Y, p, t, p.atmos.edmf_coriolis)
Expand Down
59 changes: 33 additions & 26 deletions test/parameterized_tendencies/sponge/rayleigh_sponge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,44 @@ using Revise; include("test/parameterized_tendencies/sponge/rayleigh_sponge.jl")
using ClimaComms
ClimaComms.@import_required_backends
import ClimaAtmos as CA
import SurfaceFluxes as SF
import ClimaAtmos.Parameters as CAP
import ClimaCore as CC
using ClimaCore.CommonSpaces
using ClimaCore: Spaces, Fields, Geometry, ClimaCore
using Test
using Base.Broadcast: materialize

pkgversion(ClimaCore) < v"0.14.20" && exit() # CommonSpaces
using ClimaCore.CommonSpaces

include("../../test_helpers.jl")
### Common Objects ###
@testset "Rayleigh-sponge functions" begin
### Boilerplate default integrator objects
config = CA.AtmosConfig(
Dict("initial_condition" => "DryBaroclinicWave");
job_id = "sponge1",
FT = Float64
ᶜspace = ExtrudedCubedSphereSpace(
FT;
z_elem = 10,
z_min = 0,
z_max = 1,
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = CellCenter(),
)
(; Y) = generate_test_simulation(config)
zmax = maximum(CC.Fields.coordinate_field(Y.f).z)
z = CC.Fields.coordinate_field(Y.c).z
Y.c.uₕ.components.data.:1 .= ones(axes(Y.c))
Y.c.uₕ.components.data.:2 .= ones(axes(Y.c))
FT = eltype(Y)
ᶜYₜ = zero(Y)
ᶠspace = Spaces.face_space(ᶜspace)
ᶠz = Fields.coordinate_field(ᶠspace).z
ᶜz = Fields.coordinate_field(ᶜspace).z
zmax = maximum(ᶠz)
ᶜuₕ = map(z -> zero(Geometry.Covariant12Vector{eltype(z)}), ᶜz)
@. ᶜuₕ.components.data.:1 = 1
@. ᶜuₕ.components.data.:2 = 1
### Component test begins here
rs = CA.RayleighSponge(; zd = FT(0), α_uₕ = FT(1), α_w = FT(1))
@test CA.β_rayleigh_uₕ.(rs, z, zmax) == @. sin(FT(π) / 2 * z / zmax)^2
CA.rayleigh_sponge_tendency!(ᶜYₜ, Y, nothing, FT(0), rs)
# Test that only required tendencies are affected
for (var_name) in filter(x -> (x != :uₕ), propertynames(Y.c))
@test ᶜYₜ.c.:($var_name) == zeros(axes(Y.c))
end
for (var_name) in propertynames(Y.f)
@test ᶜYₜ.f.:($var_name) == zeros(axes(Y.f))
end
@test ᶜYₜ.c.uₕ.components.data.:1 == -1 .* (CA.β_rayleigh_uₕ.(rs, z, zmax))
@test ᶜYₜ.c.uₕ.components.data.:2 == -1 .* (CA.β_rayleigh_uₕ.(rs, z, zmax))
expected = @. sin(FT(π) / 2 * ᶜz / zmax)^2
computed = CA.rayleigh_sponge_tendency_uₕ(ᶜuₕ, rs)
@test CA.β_rayleigh_uₕ.(rs, ᶜz, zmax) == expected
@test materialize(computed) == .-expected .* ᶜuₕ

# Test when not using a Rayleigh sponge.
computed = CA.rayleigh_sponge_tendency_uₕ(ᶜuₕ, nothing)
expected = @. ᶜuₕ .* 0
@test eltype(computed) == eltype(expected)
@. ᶜuₕ += computed # test that it can broadcast
end
53 changes: 28 additions & 25 deletions test/parameterized_tendencies/sponge/viscous_sponge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,33 @@ using Revise; include("test/parameterized_tendencies/sponge/viscous_sponge.jl")
using ClimaComms
ClimaComms.@import_required_backends
import ClimaAtmos as CA
import ClimaCore
using ClimaCore: Spaces, Grids, Fields
if pkgversion(ClimaCore) v"0.14.20"
using ClimaCore.CommonGrids
using Test
using ClimaCore.CommonSpaces
using ClimaCore: Spaces, Fields, Geometry, ClimaCore
using Test
using Base.Broadcast: materialize

### Common Objects ###
@testset "Viscous-sponge functions" begin
grid = ExtrudedCubedSphereGrid(;
z_elem = 10,
z_min = 0,
z_max = 1,
radius = 10,
h_elem = 10,
n_quad_points = 4,
)
cspace = Spaces.ExtrudedFiniteDifferenceSpace(grid, Grids.CellCenter())
fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace)
z = Fields.coordinate_field(cspace).z
zmax = maximum(Fields.coordinate_field(fspace).z)
FT = typeof(zmax)
### Component test begins here
s = CA.ViscousSponge{FT}(; zd = 0, κ₂ = 1)
@test CA.β_viscous.(s, z, zmax) == @. ifelse(z > s.zd, s.κ₂, FT(0)) *
sin(FT(π) / 2 * (z - s.zd) / (zmax - s.zd))^2
end
pkgversion(ClimaCore) < v"0.14.20" && exit() # CommonSpaces
using ClimaCore.CommonSpaces

### Common Objects ###
@testset "Viscous-sponge functions" begin
FT = Float64
ᶜspace = ExtrudedCubedSphereSpace(
FT;
z_elem = 10,
z_min = 0,
z_max = 1,
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = CellCenter(),
)
ᶠspace = Spaces.face_space(ᶜspace)
ᶜz = Fields.coordinate_field(ᶜspace).z
ᶠz = Fields.coordinate_field(ᶠspace).z
zmax = maximum(ᶠz)
### Component test begins here
s = CA.ViscousSponge{FT}(; zd = 0, κ₂ = 1)
@test CA.β_viscous.(s, ᶜz, zmax) == @. ifelse(ᶜz > s.zd, s.κ₂, FT(0)) *
sin(FT(π) / 2 * (ᶜz - s.zd) / (zmax - s.zd))^2
end

0 comments on commit 93ca5de

Please sign in to comment.