Skip to content

Commit

Permalink
Remove Infiltrator from test suite. (#541)
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala authored Sep 20, 2024
1 parent cd5b652 commit 0c2a8c6
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 116 deletions.
91 changes: 17 additions & 74 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
import Logging

context = ClimaComms.context()
@info "------------------------------------------------- Benchmark: gray_atm"
@suppress_out begin
include(joinpath(root_dir, "test", "gray_atm_utils.jl"))
gray_atmos_lw_equil(ClimaComms.context(), NoScatLWRTE, FT; exfiltrate = true)
end
(; slv_lw, gray_as) = Infiltrator.exfiltrated
include(joinpath(root_dir, "test", "gray_atm_utils.jl"))
gray_as, slv_lw, _ = setup_gray_atmos_lw_equil_test(context, NoScatLWRTE, FT)

@info "gray_atm lw"
solve_lw!(slv_lw, gray_as) # compile first
device = ClimaComms.device(ClimaComms.context())
device = ClimaComms.device(context)
trial = if device isa ClimaComms.CUDADevice
using CUDA
@benchmark CUDA.@sync solve_lw!($slv_lw, $gray_as)
Expand All @@ -32,11 +31,10 @@ end
show(stdout, MIME("text/plain"), trial)
println()

gray_atmos_sw_test(ClimaComms.context(), NoScatSWRTE, FT, 1; exfiltrate = true)
(; slv_sw, as) = Infiltrator.exfiltrated
as, slv_sw, _ = setup_gray_atmos_sw_test(context, NoScatSWRTE, FT, 1)

solve_sw!(slv_sw, as) # compile first
@info "gray_atm sw"
device = ClimaComms.device(ClimaComms.context())
trial = if device isa ClimaComms.CUDADevice
using CUDA
@benchmark CUDA.@sync solve_sw!($slv_sw, $as)
Expand All @@ -48,28 +46,17 @@ println()
@info "------------------------------------------------- Benchmark: clear_sky"
# @suppress_out begin
include(joinpath(root_dir, "test", "clear_sky_utils.jl"))
context = ClimaComms.context()

toler_lw_noscat = Dict(Float64 => Float64(1e-4), Float32 => Float32(0.04))
toler_lw_2stream = Dict(Float64 => Float64(4.5), Float32 => Float32(4.5))
toler_sw = Dict(Float64 => Float64(1e-3), Float32 => Float32(0.04))

clear_sky(
ClimaComms.context(),
TwoStreamLWRTE,
TwoStreamSWRTE,
VmrGM,
FT,
toler_lw_2stream,
toler_sw;
exfiltrate = true,
)
device, as, lookup_lw, lookup_sw, slv_lw, slv_sw =
setup_clear_sky_test(context, TwoStreamLWRTE, TwoStreamSWRTE, VmrGM, FT, toler_lw_2stream, toler_sw)
# end
(; slv_lw, slv_sw, as, lookup_sw, lookup_lw) = Infiltrator.exfiltrated

@info "clear_sky lw"
solve_lw!(slv_lw, as, lookup_lw) # compile first
device = ClimaComms.device(ClimaComms.context())
trial = if device isa ClimaComms.CUDADevice
using CUDA
@benchmark CUDA.@sync solve_lw!($slv_lw, $as, $lookup_lw)
Expand All @@ -81,7 +68,6 @@ println()

@info "clear_sky sw"
solve_sw!(slv_sw, as, lookup_sw) # compile first
device = ClimaComms.device(ClimaComms.context())
trial = if device isa ClimaComms.CUDADevice
using CUDA
@benchmark CUDA.@sync solve_sw!($slv_sw, $as, $lookup_sw)
Expand All @@ -98,21 +84,20 @@ include(joinpath(root_dir, "test", "all_sky_utils.jl"))
toler_lw_noscat = Dict(Float64 => Float64(1e-5), Float32 => Float32(0.05))
toler_lw_2stream = Dict(Float64 => Float64(5), Float32 => Float32(5))
toler_sw = Dict(Float64 => Float64(1e-5), Float32 => Float32(0.06))

all_sky(
use_lut = true
cldfrac = FT(1)
ncol = 128
device, as, lookup_lw, lookup_lw_cld, lookup_sw, lookup_sw_cld, slv_lw, slv_sw, _ = setup_all_sky_test(
ClimaComms.context(),
TwoStreamLWRTE,
TwoStreamSWRTE,
FT,
toler_lw_2stream,
toler_sw;
use_lut = true,
cldfrac = FT(1),
exfiltrate = true,
toler_sw,
ncol,
use_lut,
cldfrac,
)
# end

(; slv_lw, slv_sw, as, lookup_sw, lookup_sw_cld, lookup_lw, lookup_lw_cld) = Infiltrator.exfiltrated

solve_sw!(slv_sw, as, lookup_sw, lookup_sw_cld) # compile first
solve_lw!(slv_lw, as, lookup_lw, lookup_lw_cld) # compile first
Expand All @@ -137,45 +122,3 @@ else
end
show(stdout, MIME("text/plain"), trial)
println()

#=
# @suppress_out begin
all_sky(
ClimaComms.context(),
TwoStream,
TwoStream,
TwoStreamLWRTE,
TwoStreamSWRTE,
FT;
use_lut = false,
cldfrac = FT(1),
exfiltrate = true,
)
# end
(; slv_lw, slv_sw, as, lookup_sw, lookup_sw_cld, lookup_lw, lookup_lw_cld) = Infiltrator.exfiltrated
solve_sw!(slv_sw, as, lookup_sw, lookup_sw_cld) # compile first
solve_lw!(slv_lw, as, lookup_lw, lookup_lw_cld) # compile first
@info "all_sky, lw, use_lut=false"
device = ClimaComms.device(ClimaComms.context())
trial = if device isa ClimaComms.CUDADevice
using CUDA
@benchmark CUDA.@sync solve_lw!($slv_lw, $as, $lookup_lw, $lookup_lw_cld)
else
@benchmark solve_lw!($slv_lw, $as, $lookup_lw, $lookup_lw_cld)
end
show(stdout, MIME("text/plain"), trial)
println()
@info "all_sky, sw, use_lut=false"
device = ClimaComms.device(ClimaComms.context())
trial = if device isa ClimaComms.CUDADevice
using CUDA
@benchmark CUDA.@sync solve_sw!($slv_sw, $as, $lookup_sw, $lookup_sw_cld)
else
@benchmark solve_sw!($slv_sw, $as, $lookup_sw, $lookup_sw_cld)
end
show(stdout, MIME("text/plain"), trial)
println()
=#
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
30 changes: 22 additions & 8 deletions test/all_sky_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using Pkg.Artifacts
using NCDatasets

import JET
import Infiltrator
import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

Expand All @@ -26,17 +25,16 @@ using RRTMGP.ArtifactPaths
include("reference_files.jl")
include("read_all_sky.jl")

function all_sky(
function setup_all_sky_test(
context,
::Type{SLVLW},
::Type{SLVSW},
::Type{FT},
toler_lw,
toler_sw;
ncol = 128,# repeats col#1 ncol times per RRTMGP example
use_lut::Bool = true,
cldfrac = FT(1),
exfiltrate = false,
toler_sw,
ncol,# repeats col#1 ncol times per RRTMGP example
use_lut::Bool,
cldfrac,
) where {FT <: AbstractFloat, SLVLW, SLVSW}
overrides = (; grav = 9.80665, molmass_dryair = 0.028964, molmass_water = 0.018016)
param_set = RRTMGPParameters(FT, overrides)
Expand Down Expand Up @@ -94,8 +92,24 @@ function all_sky(
inc_flux_diffuse = nothing
swbcs = (cos_zenith, toa_flux, sfc_alb_direct, inc_flux_diffuse, sfc_alb_diffuse)
slv_sw = SLVSW(FT, DA, context, nlay, ncol, swbcs...)

return device, as, lookup_lw, lookup_lw_cld, lookup_sw, lookup_sw_cld, slv_lw, slv_sw, (bot_at_1, nlev)
end

function all_sky(
context,
::Type{SLVLW},
::Type{SLVSW},
::Type{FT},
toler_lw,
toler_sw;
ncol = 128,# repeats col#1 ncol times per RRTMGP example
use_lut::Bool = true,
cldfrac = FT(1),
) where {FT <: AbstractFloat, SLVLW, SLVSW}
device, as, lookup_lw, lookup_lw_cld, lookup_sw, lookup_sw_cld, slv_lw, slv_sw, (bot_at_1, nlev) =
setup_all_sky_test(context, SLVLW, SLVSW, FT, toler_lw, toler_sw, ncol, use_lut, cldfrac)
#------calling solvers
exfiltrate && Infiltrator.@exfiltrate
solve_lw!(slv_lw, as, lookup_lw, lookup_lw_cld)
if device isa ClimaComms.CPUSingleThreaded
JET.@test_opt solve_lw!(slv_lw, as, lookup_lw, lookup_lw_cld)
Expand Down
2 changes: 0 additions & 2 deletions test/all_sky_with_aerosols_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using Pkg.Artifacts
using NCDatasets

import JET
import Infiltrator
import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

Expand Down Expand Up @@ -36,7 +35,6 @@ function all_sky_with_aerosols(
ncol = 128,# repeats col#1 ncol times per RRTMGP example
use_lut::Bool = true,
cldfrac = FT(1),
exfiltrate = false,
) where {FT <: AbstractFloat, SLVLW, SLVSW}
overrides = (; grav = 9.80665, molmass_dryair = 0.028964, molmass_water = 0.018016)
param_set = RRTMGPParameters(FT, overrides)
Expand Down
26 changes: 19 additions & 7 deletions test/clear_sky_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import JET
import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

import Infiltrator

using RRTMGP
using RRTMGP.Vmrs
using RRTMGP.LookUpTables
Expand All @@ -26,17 +24,16 @@ using RRTMGP.ArtifactPaths

include("reference_files.jl")
include("read_rfmip_clear_sky.jl")
#---------------------------------------------------------------
function clear_sky(

function setup_clear_sky_test(
context,
::Type{SLVLW},
::Type{SLVSW},
::Type{VMR},
::Type{FT},
toler_lw,
toler_sw;
toler_sw,
ncol = 100,
exfiltrate = false,
) where {FT, SLVLW, SLVSW, VMR}
overrides = (; grav = 9.80665, molmass_dryair = 0.028964, molmass_water = 0.018016)
param_set = RRTMGPParameters(FT, overrides)
Expand Down Expand Up @@ -78,9 +75,24 @@ function clear_sky(
inc_flux_diffuse = nothing
swbcs = (cos_zenith, toa_flux, sfc_alb_direct, inc_flux_diffuse, sfc_alb_diffuse)
slv_sw = SLVSW(FT, DA, context, nlay, ncol, swbcs...)

return device, as, lookup_lw, lookup_sw, slv_lw, slv_sw, (bot_at_1, nlev, expt_no, cos_zenith)
end
#---------------------------------------------------------------
function clear_sky(
context,
::Type{SLVLW},
::Type{SLVSW},
::Type{VMR},
::Type{FT},
toler_lw,
toler_sw;
ncol = 100,
) where {FT, SLVLW, SLVSW, VMR}
device, as, lookup_lw, lookup_sw, slv_lw, slv_sw, (bot_at_1, nlev, expt_no, cos_zenith) =
setup_clear_sky_test(context, SLVLW, SLVSW, VMR, FT, toler_lw, toler_sw, ncol)
#--------------------------------------------------
# calling longwave and shortwave solvers
exfiltrate && Infiltrator.@exfiltrate
solve_lw!(slv_lw, as, lookup_lw)
if device isa ClimaComms.CPUSingleThreaded
JET.@test_opt solve_lw!(slv_lw, as, lookup_lw)
Expand Down
56 changes: 32 additions & 24 deletions test/gray_atm_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

import JET
import Infiltrator
using RRTMGP
using RRTMGP.AngularDiscretizations
using RRTMGP.Fluxes
Expand All @@ -19,12 +18,8 @@ import ClimaParams as CP

# overriding ClimaParams as different precision is needed by RRTMGP

#using Plots

"""
Example program to demonstrate the calculation of longwave radiative fluxes in a model gray atmosphere.
"""
function gray_atmos_lw_equil(context, ::Type{SLVLW}, ::Type{FT}; exfiltrate = false) where {FT <: AbstractFloat, SLVLW}
function setup_gray_atmos_lw_equil_test(context, ::Type{SLVLW}, ::Type{FT}) where {FT <: AbstractFloat, SLVLW}
device = ClimaComms.device(context)
param_set = RRTMGPParameters(FT)
ncol = if device isa ClimaComms.CUDADevice
Expand All @@ -38,14 +33,6 @@ function gray_atmos_lw_equil(context, ::Type{SLVLW}, ::Type{FT}; exfiltrate = fa
pe = FT(9000) # TOA pressure (Pa)
nbnd, ngpt = 1, 1 # # of nbands/g-points (=1 for gray radiation)
nlev = nlay + 1 # # of layers
tb = FT(320) # surface temperature
tstep = 6.0 # timestep in hours
Δt = FT(60 * 60 * tstep) # timestep in seconds
ndays = 365 * 40 # # of simulation days
nsteps = ndays * (24 / tstep) # number of timesteps
temp_toler = FT(0.1) # tolerance for temperature (Kelvin)
flux_grad_toler = FT(1e-5) # tolerance for flux gradient
n_gauss_angles = 1 # for non-scattering calculation
sfc_emis = DA{FT}(undef, nbnd, ncol) # surface emissivity
sfc_emis .= FT(1.0)
inc_flux = nothing # incoming flux
Expand All @@ -60,6 +47,23 @@ function gray_atmos_lw_equil(context, ::Type{SLVLW}, ::Type{FT}; exfiltrate = fa
gray_as = setup_gray_as_pr_grid(context, nlay, lat, p0, pe, otp, param_set, DA)
slv_lw = SLVLW(FT, DA, context, param_set, nlay, ncol, sfc_emis, inc_flux)

return gray_as, slv_lw, (DA, param_set, nlev, nlay, ncol)
end

"""
Example program to demonstrate the calculation of longwave radiative fluxes in a model gray atmosphere.
"""
function gray_atmos_lw_equil(context, ::Type{SLVLW}, ::Type{FT}) where {FT <: AbstractFloat, SLVLW}
gray_as, slv_lw, (DA, param_set, nlev, nlay, ncol) = setup_gray_atmos_lw_equil_test(context, SLVLW, FT)

tb = FT(320) # surface temperature
tstep = 6.0 # timestep in hours
Δt = FT(60 * 60 * tstep) # timestep in seconds
ndays = 365 * 40 # # of simulation days
nsteps = ndays * (24 / tstep) # number of timesteps
temp_toler = FT(0.1) # tolerance for temperature (Kelvin)
flux_grad_toler = FT(1e-5) # tolerance for flux gradient

(; flux_up, flux_dn, flux_net) = slv_lw.flux
(; t_lay, p_lay, t_lev, p_lev) = gray_as
sbc = FT(RRTMGP.Parameters.Stefan(param_set))
Expand All @@ -70,7 +74,6 @@ function gray_atmos_lw_equil(context, ::Type{SLVLW}, ::Type{FT}; exfiltrate = fa
T_ex_lev = DA{FT}(undef, ncol, nlev)
flux_grad = DA{FT}(undef, ncol, nlay)
flux_grad_err = FT(0)
exfiltrate && Infiltrator.@exfiltrate
device = ClimaComms.device(context)
for i in 1:nsteps
# calling the long wave gray radiation solver
Expand Down Expand Up @@ -115,13 +118,7 @@ function gray_atmos_lw_equil(context, ::Type{SLVLW}, ::Type{FT}; exfiltrate = fa
end
end

function gray_atmos_sw_test(
context,
::Type{SLVSW},
::Type{FT},
ncol::Int;
exfiltrate = false,
) where {FT <: AbstractFloat, SLVSW}
function setup_gray_atmos_sw_test(context, ::Type{SLVSW}, ::Type{FT}, ncol::Int) where {FT <: AbstractFloat, SLVSW}
param_set = RRTMGPParameters(FT)
device = ClimaComms.device(context)
DA = ClimaComms.array_type(device)
Expand All @@ -135,7 +132,6 @@ function gray_atmos_sw_test(
Δt = FT(60 * 60 * tstep) # timestep in seconds
ndays = 365 * 40 # # of simulation days
nsteps = ndays * (24 / tstep) # number of timesteps
n_gauss_angles = 1 # for non-scattering calculation
sfc_emis = Array{FT}(undef, nbnd, ncol) # surface emissivity
sfc_emis .= FT(1.0)
deg2rad = FT(π) / FT(180)
Expand Down Expand Up @@ -166,7 +162,19 @@ function gray_atmos_sw_test(
swbcs = (cos_zenith, toa_flux, sfc_alb_direct, inc_flux_diffuse, sfc_alb_diffuse)
slv_sw = SLVSW(FT, DA, context, nlay, ncol, swbcs...)

exfiltrate && Infiltrator.@exfiltrate
return as, slv_sw, (device, param_set, nlev, nlay, ncol, cos_zenith, toa_flux)
end

function gray_atmos_sw_test(context, ::Type{SLVSW}, ::Type{FT}, ncol::Int) where {FT <: AbstractFloat, SLVSW}
as, slv_sw, (device, param_set, nlev, nlay, ncol, cos_zenith, toa_flux) =
setup_gray_atmos_sw_test(context, SLVSW, FT, ncol)

tb = FT(320) # surface temperature
tstep = 6.0 # timestep in hours
Δt = FT(60 * 60 * tstep) # timestep in seconds
ndays = 365 * 40 # # of simulation days
nsteps = ndays * (24 / tstep) # number of timesteps

solve_sw!(slv_sw, as)

τ = Array(slv_sw.op.τ)
Expand Down

0 comments on commit 0c2a8c6

Please sign in to comment.