Skip to content

Commit

Permalink
Add slurm workers for calibration end-to-end test
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed Dec 2, 2024
1 parent 33a2af5 commit 8cbe910
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 173 deletions.
5 changes: 5 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,11 @@ steps:
command: julia --project=calibration/test calibration/test/interface.jl
- label: "end to end test"
command: julia --project=calibration/test calibration/test/e2e_test.jl
agents:
slurm_ntasks: 12
slurm_cpus_per_task: 1
slurm_mem: 96GB
slurm_time: "00:10:00"
artifact_paths: "calibration_end_to_end_test/*"
soft_fail: true

Expand Down
31 changes: 0 additions & 31 deletions .github/workflows/calibration_test.yml

This file was deleted.

3 changes: 2 additions & 1 deletion calibration/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
ClusterManagers = "34f1f09b-3a8b-5176-ab39-66d58a4d544e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
ClimaCalibrate = "0.0.2 - 0.0.4"
ClimaCalibrate = "0.0.5"
290 changes: 149 additions & 141 deletions calibration/test/e2e_test.jl
Original file line number Diff line number Diff line change
@@ -1,67 +1,11 @@
#= End-to-end test
Runs a perfect model calibration, calibrating on the parameter `astronomical_unit`
with top-of-atmosphere radiative shortwave flux in the loss function.
The calibration is run twice, once on the backend obtained via `get_backend()`
and once on the `JuliaBackend`. The output of each calibration is tested individually
and compared to ensure reproducibility.
=#
using Distributed, ClusterManagers
import ClimaCalibrate as CAL
import ClimaAtmos as CA
import ClimaAnalysis: SimDir, get, slice, average_xy
import CairoMakie
import JLD2
import LinearAlgebra: I
import EnsembleKalmanProcesses as EKP
import Statistics: var, mean
using Test

using Dates

# Debug plots
function scatter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Top of atmosphere radiative SW flux",
)

g = vec.(EKP.get_g(eki; return_array = true))
params = vec.((EKP.get_ϕ(prior, eki)))

for (gg, uu) in zip(g, params)
CairoMakie.scatter!(ax, gg, uu)
end

CairoMakie.hlines!(ax, [astronomical_unit], linestyle = :dash)
CairoMakie.vlines!(ax, observations, linestyle = :dash)

output = joinpath(output_dir, "scatter.png")
CairoMakie.save(output, f)
return output
end

function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Iteration",
)
params = EKP.get_ϕ(prior, eki)
for (i, param) in enumerate(params)
CairoMakie.scatter!(ax, fill(i, length(param)), vec(param))
end

CairoMakie.hlines!(ax, [astronomical_unit]; color = :red, linestyle = :dash)

output = joinpath(output_dir, "param_vs_iter.png")
CairoMakie.save(output, f)
return output
end

# Observation map
function CAL.observation_map(iteration)
single_member_dims = (1,)
G_ensemble = Array{Float64}(undef, single_member_dims..., ensemble_size)
Expand All @@ -86,113 +30,177 @@ function process_member_data(simdir::SimDir)
return slice(average_xy(rsut); time = 30).data
end

# EKI test
function minimal_eki_test(eki)
params = EKP.get_ϕ(prior, eki)
spread = map(var, params)
# TODO: These functions will be moved to ClimaCalibrate
function run_iteration(config, iter, worker_pool)
# Create a channel to collect results
results = Channel{Any}(config.ensemble_size)
@sync begin
for m in 1:(config.ensemble_size)
@async begin
# Get a worker from the pool
worker = take!(worker_pool)
try
@info "Running member $m on worker $worker"
# Run the model and put result in channel
model_config = CAL.set_up_forward_model(m, iter, config)
result = remotecall_fetch(
CAL.run_forward_model,
worker,
model_config,
)
put!(results, (m, result))
catch e
@error "Error running member $m" exception = e
put!(results, (m, e))
finally
# Always return worker to pool
put!(worker_pool, worker)
end
end
end
end

# Spread should be heavily decreased as particles have converged
@test last(spread) / first(spread) < 0.1
# Parameter should be close to true value
@test mean(last(params)) astronomical_unit rtol = 0.02
# Collect all results
ensemble_results = Dict{Int, Any}()
for _ in 1:(config.ensemble_size)
m, result = take!(results)
if result isa Exception
@error "Member $m failed" error = result
else
ensemble_results[m] = result
end
end
results = values(ensemble_results)
all(isa.(results, Exception)) &&
error("Full ensemble for iter $iter failed")

# Process results
G_ensemble = CAL.observation_map(iter)
CAL.save_G_ensemble(config, iter, G_ensemble)
CAL.update_ensemble(config, iter)
iter_path = CAL.path_to_iteration(config.output_dir, iter)
return JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
end

# Script:

if !(@isdefined backend)
backend = CAL.get_backend()
end
# Check that the wait time for the last hour does not exceed 20 minutes.
# This test schedules many slurm jobs and will be prohibitively slow if the cluster is busy
if backend <: CAL.HPCBackend
wait_times = readchomp(
`sacct --allocations -u esmbuild --starttime now-1hour -o Submit,Start -n`,
)
wait_times = split(wait_times, '\n', keepempty = false)
# Filter jobs that have not been submitted and started
filter!(x -> !(contains(x, "Unknown") || contains(x, "None")), wait_times)

mean_wait_time_in_mins =
mapreduce(+, wait_times; init = 0) do line
t1_str, t2_str = split(line)
t1 = DateTime(t1_str, dateformat"yyyy-mm-ddTHH:MM:SS")
t2 = DateTime(t2_str, dateformat"yyyy-mm-ddTHH:MM:SS")
Dates.value(t2 - t1) / 1000 / 60
end / length(wait_times)

@show mean_wait_time_in_mins

if mean_wait_time_in_mins > 10
@warn """Average wait time for esmbuild is $(round(mean_wait_time_in_mins, digits=2)) minutes. \
Cluster is too busy to run this test, exiting"""
exit()
function calibrate(config, worker_pool)
CAL.initialize(config)
for iter in 0:(config.n_iterations)
(; time) = @timed run_iteration(config, iter, worker_pool)
@info "Iteration $iter time: $time"
end
end

# Paths and setup
const experiment_dir = joinpath(pkgdir(CA), "calibration", "test")
const model_interface =
joinpath(pkgdir(CA), "calibration", "model_interface.jl")
const output_dir = "calibration_end_to_end_test"
include(model_interface)
ensemble_size = 15
addprocs(
SlurmManager(10),
t = "00:20:00",
cpus_per_task = 1,
exeflags = "--project=$(Base.active_project())",
)
worker_pool = WorkerPool(workers())

@everywhere println(string(myid()))

@everywhere begin
import ClimaCalibrate as CAL
import ClimaAtmos as CA

experiment_dir = joinpath(pkgdir(CA), "calibration", "test")
model_interface = joinpath(pkgdir(CA), "calibration", "model_interface.jl")
output_dir = "calibration_end_to_end_test"
include(model_interface)
ensemble_size = 50
obs_path = joinpath(experiment_dir, "observations.jld2")
end

# Generate observations
obs_path = joinpath(experiment_dir, "observations.jld2")
# Generate observations if needed
if !isfile(obs_path)
import JLD2
@info "Generating observations"
config = CA.AtmosConfig(joinpath(experiment_dir, "model_config.yml"))
simulation = CA.get_simulation(config)
atmos_config = CA.AtmosConfig(joinpath(experiment_dir, "model_config.yml"))
simulation = CA.get_simulation(atmos_config)
CA.solve_atmos!(simulation)
observations = Vector{Float64}(undef, 1)
observations .= process_member_data(SimDir(simulation.output_dir))
JLD2.save_object(obs_path, observations)
end

# Initialize experiment data
astronomical_unit = 149_597_870_000
observations = JLD2.load_object(obs_path)
noise = 0.1 * I
n_iterations = 4
prior = CAL.get_prior(joinpath(experiment_dir, "prior.toml"))
experiment_config = CAL.ExperimentConfig(;
n_iterations,
ensemble_size,
observations,
noise,
output_dir,
prior,
)

@info "Running calibration E2E test" backend
if backend <: CAL.HPCBackend
test_eki = CAL.calibrate(
backend,
experiment_config;
hpc_kwargs = CAL.kwargs(time = 15),
model_interface,
verbose = true,
@everywhere begin
import JLD2
import LinearAlgebra: I
astronomical_unit = 149_597_870_000
observations = JLD2.load_object(obs_path)
noise = 0.1 * I
n_iterations = 10
prior = CAL.get_prior(joinpath(experiment_dir, "prior.toml"))

config = CAL.ExperimentConfig(;
n_iterations,
ensemble_size,
observations,
noise,
output_dir,
prior,
)
else
test_eki = CAL.calibrate(backend, experiment_config)
end

scatter_plot(test_eki)
param_versus_iter_plot(test_eki)
calibrate(config, worker_pool)

@testset "Test Calibration on $backend" begin
minimal_eki_test(test_eki)
end
import EnsembleKalmanProcesses as EKP
import Statistics: var, mean
using Test
import CairoMakie

function scatter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Top of atmosphere radiative SW flux",
)

g = vec.(EKP.get_g(eki; return_array = true))
params = vec.((EKP.get_ϕ(prior, eki)))

for (gg, uu) in zip(g, params)
CairoMakie.scatter!(ax, gg, uu)
end

# Run calibration
julia_eki = CAL.calibrate(CAL.JuliaBackend, experiment_config)
CairoMakie.hlines!(ax, [astronomical_unit], linestyle = :dash)
CairoMakie.vlines!(ax, observations, linestyle = :dash)

@testset "Julia-only comparison calibration" begin
minimal_eki_test(julia_eki)
output = joinpath(output_dir, "scatter.png")
CairoMakie.save(output, f)
return output
end

@testset "Compare $backend output to JuliaBackend" begin
for (uu, slurm_uu) in zip(EKP.get_u(julia_eki), EKP.get_u(test_eki))
@test uu slurm_uu rtol = 0.02
function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Iteration",
)
params = EKP.get_ϕ(prior, eki)
for (i, param) in enumerate(params)
CairoMakie.scatter!(ax, fill(i, length(param)), vec(param))
end

CairoMakie.hlines!(ax, [astronomical_unit]; color = :red, linestyle = :dash)

output = joinpath(output_dir, "param_vs_iter.png")
CairoMakie.save(output, f)
return output
end

eki = JLD2.load_object(joinpath(output_dir, "iteration_011", "eki_file.jld2"))
scatter_plot(eki)
param_versus_iter_plot(eki)

params = EKP.get_ϕ(prior, eki)
spread = map(var, params)

# Spread should be heavily decreased as particles have converged
@test last(spread) / first(spread) < 0.1
# Parameter should be close to true value
@test mean(last(params)) astronomical_unit rtol = 0.02

0 comments on commit 8cbe910

Please sign in to comment.